WSLINMSAI committed on
Commit
aa8cdfb
·
verified ·
1 Parent(s): e4ac2fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -71
app.py CHANGED
@@ -1,24 +1,21 @@
1
  import random
2
  import gradio as gr
3
- import torch
4
- from transformers import pipeline
5
 
6
  # -----------------------------
7
- # 1. Load the Question Generation Pipeline
 
8
  # -----------------------------
9
- # Model reference: https://huggingface.co/valhalla/t5-small-qg-hl
 
10
  qg_pipeline = pipeline(
11
  "text2text-generation",
12
- model="valhalla/t5-small-qg-hl",
13
- tokenizer="valhalla/t5-small-qg-hl",
14
- use_fast=False # disable the fast tokenizer to avoid tiktoken conversion issues
15
  )
16
 
17
  # -----------------------------
18
- # 2. Prepare Passages (with <hl> tags)
19
- # We'll store a few short passages at different difficulty levels.
20
- # The <hl> tags highlight the answer, which helps the QG pipeline
21
- # know which part to form a question about.
22
  # -----------------------------
23
  passages = {
24
  "easy": [
@@ -39,12 +36,7 @@ passages = {
39
  }
40
 
41
  # -----------------------------
42
- # 3. Session State Setup
43
- # We'll track:
44
- # - difficulty: "easy", "medium", or "hard"
45
- # - score: integer, increments when correct, decrements when wrong
46
- # - question: the last generated question text
47
- # - answer: the correct answer for the last question
48
  # -----------------------------
49
  def init_state():
50
  return {
@@ -55,62 +47,50 @@ def init_state():
55
  }
56
 
57
  # -----------------------------
58
- # 4. Difficulty Adjustment Rules
59
- # - If score >= +2, increase difficulty (up to "hard").
60
- # - If score <= -2, decrease difficulty (down to "easy").
61
  # -----------------------------
62
  def adjust_difficulty(state):
63
  diff_order = ["easy", "medium", "hard"]
64
  idx = diff_order.index(state["difficulty"])
65
 
66
  if state["score"] >= 2 and idx < len(diff_order) - 1:
67
- # Increase difficulty
68
  state["difficulty"] = diff_order[idx + 1]
69
- state["score"] = 0 # reset score so we can re-check for next threshold
70
  return "Difficulty increased to: " + state["difficulty"]
71
  elif state["score"] <= -2 and idx > 0:
72
- # Decrease difficulty
73
  state["difficulty"] = diff_order[idx - 1]
74
- state["score"] = 0 # reset score
75
  return "Difficulty decreased to: " + state["difficulty"]
76
  else:
77
- return f"Difficulty remains: {state['difficulty']}"
78
 
79
  # -----------------------------
80
  # 5. Generate a Question from a Passage
81
- # We'll pick a random passage from the current difficulty level,
82
- # pass it to the QG pipeline, and store the result in the state.
83
  # -----------------------------
84
  def generate_question(state):
85
- # Pick a random passage from the current difficulty
86
  passage_list = passages[state["difficulty"]]
87
  chosen_passage = random.choice(passage_list)
88
 
89
- # The correct answer is the text between <hl> tags
90
- # We'll extract it to compare with user input later
91
- # Example: "The capital of <hl>France<hl> is Paris." -> answer = "France"
92
  parts = chosen_passage.split("<hl>")
93
  if len(parts) == 3:
94
- # e.g., ["The capital of ", "France", " is Paris."]
95
  answer = parts[1].strip()
96
  else:
97
  answer = "N/A"
98
 
99
- # The QG pipeline expects the passage with <hl> tokens
100
- # We'll feed that directly
101
  result = qg_pipeline(chosen_passage, max_length=64)
102
  question_text = result[0]["generated_text"]
103
 
104
- # Update state
105
  state["question"] = question_text
106
  state["answer"] = answer
107
 
108
- # Return question to display
109
  return question_text
110
 
111
  # -----------------------------
112
  # 6. Check the User's Answer
113
- # We do a simple string comparison (case-insensitive).
114
  # -----------------------------
115
  def check_answer(state, user_answer):
116
  correct_answer = state["answer"].lower().strip()
@@ -118,77 +98,59 @@ def check_answer(state, user_answer):
118
 
119
  if user_answer_clean == correct_answer:
120
  state["score"] += 1
121
- result = "Correct!"
122
  else:
123
  state["score"] -= 1
124
- result = f"Incorrect! The correct answer was: {state['answer']}"
125
 
126
- # Possibly adjust difficulty
127
  difficulty_update = adjust_difficulty(state)
128
- return result + "\n" + difficulty_update
129
 
130
  # -----------------------------
131
- # 7. Gradio Interface
132
- # We'll build a small flow:
133
- # - Show current difficulty
134
- # - "Generate Question" button
135
- # - Show question
136
- # - Text input for answer
137
- # - "Submit Answer" button
138
- # - Show result
139
  # -----------------------------
140
  with gr.Blocks() as demo:
141
- # We need a state object that persists across user interactions
142
  state = gr.State(init_state())
143
 
144
  gr.Markdown("# Adaptive Language Tutor")
145
- gr.Markdown("This demo uses a T5-based model to generate questions. "
146
- "Difficulty automatically adjusts based on your answers.")
 
 
147
 
148
- # Display current difficulty
149
  difficulty_label = gr.Markdown("**Difficulty**: (will be updated)")
150
 
151
- # Button + output area for generating question
152
  with gr.Row():
153
  generate_button = gr.Button("Generate Question")
154
  question_output = gr.Textbox(label="Question", interactive=False)
155
 
156
- # Text input + button to submit answer
157
  user_answer = gr.Textbox(label="Your Answer")
158
  submit_button = gr.Button("Submit Answer")
159
-
160
- # Result output
161
  result_output = gr.Textbox(label="Result", interactive=False)
162
 
163
- # -- Define event functions --
164
  def update_difficulty_label(state):
165
  return f"**Difficulty**: {state['difficulty']} (Score: {state['score']})"
166
 
167
- # 1) On load, update difficulty label
168
  demo.load(fn=update_difficulty_label, inputs=state, outputs=difficulty_label)
169
 
170
- # 2) Generate question
171
  def on_generate_question(state):
172
  question = generate_question(state)
173
  difficulty_text = update_difficulty_label(state)
174
  return question, difficulty_text
175
 
176
- generate_button.click(
177
- fn=on_generate_question,
178
- inputs=state,
179
- outputs=[question_output, difficulty_label]
180
- )
181
 
182
- # 3) Submit answer
183
  def on_submit_answer(user_answer, state):
184
  feedback = check_answer(state, user_answer)
185
  difficulty_text = update_difficulty_label(state)
186
  return feedback, difficulty_text
187
 
188
- submit_button.click(
189
- fn=on_submit_answer,
190
- inputs=[user_answer, state],
191
- outputs=[result_output, difficulty_label]
192
- )
193
 
194
  demo.launch()
 
import random
import gradio as gr
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline

# -----------------------------
# 1. Load the Model & Slow Tokenizer
# The slow (SentencePiece) tokenizer is requested explicitly via
# use_fast=False; the fast-tokenizer conversion is deliberately avoided
# for this checkpoint.
# -----------------------------
_CHECKPOINT = "valhalla/t5-small-qg-hl"
tokenizer = T5Tokenizer.from_pretrained(_CHECKPOINT, use_fast=False)
model = T5ForConditionalGeneration.from_pretrained(_CHECKPOINT)
qg_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
)
16
 
17
  # -----------------------------
18
+ # 2. Define Passages by Difficulty
 
 
 
19
  # -----------------------------
20
  passages = {
21
  "easy": [
 
36
  }
37
 
38
  # -----------------------------
39
+ # 3. Session State Initialization
 
 
 
 
 
40
  # -----------------------------
41
  def init_state():
42
  return {
 
47
  }
48
 
49
# -----------------------------
# 4. Adjust Difficulty Based on Score
# A running score of +2 moves up one level (capped at "hard");
# -2 moves down one level (floored at "easy"). The score resets to 0
# whenever the level actually changes.
# -----------------------------
def adjust_difficulty(state):
    """Shift the session difficulty one step based on the running score.

    Mutates ``state`` in place and returns a human-readable status line
    describing what happened.
    """
    levels = ["easy", "medium", "hard"]
    position = levels.index(state["difficulty"])

    can_go_up = state["score"] >= 2 and position + 1 < len(levels)
    can_go_down = state["score"] <= -2 and position > 0

    if can_go_up:
        state["difficulty"] = levels[position + 1]
        state["score"] = 0  # Reset score upon difficulty change
        return "Difficulty increased to: " + state["difficulty"]
    if can_go_down:
        state["difficulty"] = levels[position - 1]
        state["score"] = 0  # Reset score upon difficulty change
        return "Difficulty decreased to: " + state["difficulty"]
    return f"Difficulty remains: {state['difficulty']} (Score: {state['score']})"
66
 
67
# -----------------------------
# 5. Generate a Question from a Passage
# -----------------------------
def generate_question(state):
    """Produce a new question for the session's current difficulty.

    Picks a random ``<hl>``-annotated passage, records the highlighted span
    as the expected answer, runs the question-generation pipeline, and
    stores both question and answer in ``state``.

    Returns the generated question text.
    """
    # Select a random passage from the current difficulty level.
    chosen = random.choice(passages[state["difficulty"]])

    # The answer is the text between the two <hl> markers; splitting on
    # "<hl>" yields exactly three parts when the markup is well-formed.
    pieces = chosen.split("<hl>")
    answer = pieces[1].strip() if len(pieces) == 3 else "N/A"

    # Feed the highlighted passage straight to the QG pipeline.
    generated = qg_pipeline(chosen, max_length=64)
    question = generated[0]["generated_text"]

    # Record the pair so check_answer() can grade the reply later.
    state["question"] = question
    state["answer"] = answer

    return question
91
 
92
# -----------------------------
# 6. Check the User's Answer
# -----------------------------
def check_answer(state, user_answer):
    """Grade ``user_answer`` against the stored answer and update the score.

    Comparison is case-insensitive and ignores surrounding whitespace.
    Returns the feedback text plus the difficulty-adjustment status line.
    """
    expected = state["answer"].lower().strip()
    # NOTE(review): the normalization of user_answer was elided in the diff
    # view; reconstructed as lower().strip() to mirror the stored answer.
    given = user_answer.lower().strip()

    if given == expected:
        state["score"] += 1
        result_text = "Correct!"
    else:
        state["score"] -= 1
        result_text = f"Incorrect! The correct answer was: {state['answer']}"

    # Re-evaluate difficulty now that the score has changed.
    difficulty_update = adjust_difficulty(state)
    return result_text + "\n" + difficulty_update
109
 
110
# -----------------------------
# 7. Build the Gradio Interface
# -----------------------------
with gr.Blocks() as demo:
    # Per-session state: difficulty, score, last question and its answer.
    state = gr.State(init_state())

    gr.Markdown("# Adaptive Language Tutor")
    gr.Markdown(
        "This demo uses a T5-based model to generate questions from a passage. "
        "Difficulty will automatically adjust based on your performance."
    )

    # Live readout of the current difficulty and score.
    difficulty_label = gr.Markdown("**Difficulty**: (will be updated)")

    with gr.Row():
        generate_button = gr.Button("Generate Question")
        question_output = gr.Textbox(label="Question", interactive=False)

    user_answer = gr.Textbox(label="Your Answer")
    submit_button = gr.Button("Submit Answer")
    result_output = gr.Textbox(label="Result", interactive=False)

    def update_difficulty_label(state):
        """Render the current difficulty and score as a Markdown line."""
        return f"**Difficulty**: {state['difficulty']} (Score: {state['score']})"

    # Refresh the difficulty readout as soon as the page loads.
    demo.load(fn=update_difficulty_label, inputs=state, outputs=difficulty_label)

    # Event handler: produce a new question and refresh the readout.
    def on_generate_question(state):
        new_question = generate_question(state)
        label_text = update_difficulty_label(state)
        return new_question, label_text

    generate_button.click(fn=on_generate_question, inputs=state, outputs=[question_output, difficulty_label])

    # Event handler: grade the submitted answer and refresh the readout.
    def on_submit_answer(user_answer, state):
        feedback_text = check_answer(state, user_answer)
        label_text = update_difficulty_label(state)
        return feedback_text, label_text

    submit_button.click(fn=on_submit_answer, inputs=[user_answer, state], outputs=[result_output, difficulty_label])

demo.launch()