Update project_model.py
Browse files- project_model.py +19 -5
project_model.py
CHANGED
|
@@ -77,11 +77,22 @@ class VisualQAState:
|
|
| 77 |
|
| 78 |
def add_question(self, question: str):
|
| 79 |
"""
|
| 80 |
-
Adds a follow-up
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
"""
|
| 82 |
self.message_history.append({
|
| 83 |
-
"role": "
|
| 84 |
-
"content": [{"type": "text", "text":
|
| 85 |
})
|
| 86 |
|
| 87 |
# -------------------------------
|
|
@@ -171,13 +182,16 @@ def process_inputs(
|
|
| 171 |
audio_text = whisper_pipe(audio_path)["text"]
|
| 172 |
question += " " + audio_text
|
| 173 |
|
| 174 |
-
# Append question to conversation history
|
| 175 |
session.add_question(question)
|
| 176 |
|
| 177 |
# Generate response using GEMMA with full conversation history
|
| 178 |
gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
|
| 179 |
answer = gemma_output[0]["generated_text"][-1]["content"]
|
| 180 |
|
|
|
|
|
|
|
|
|
|
| 181 |
# If TTS is enabled, synthesize answer as speech
|
| 182 |
output_audio_path = "response.wav"
|
| 183 |
if enable_tts:
|
|
@@ -185,4 +199,4 @@ def process_inputs(
|
|
| 185 |
else:
|
| 186 |
output_audio_path = None
|
| 187 |
|
| 188 |
-
return answer, output_audio_path
|
|
|
|
| 77 |
|
| 78 |
def add_question(self, question: str):
|
| 79 |
"""
|
| 80 |
+
Adds a follow-up question only if the last message was from assistant.
|
| 81 |
+
Ensures alternating user/assistant messages.
|
| 82 |
+
"""
|
| 83 |
+
if not self.message_history or self.message_history[-1]["role"] == "assistant":
|
| 84 |
+
self.message_history.append({
|
| 85 |
+
"role": "user",
|
| 86 |
+
"content": [{"type": "text", "text": question}]
|
| 87 |
+
})
|
| 88 |
+
|
| 89 |
+
def add_answer(self, answer: str):
|
| 90 |
+
"""
|
| 91 |
+
Appends the assistant's response to the conversation history.
|
| 92 |
"""
|
| 93 |
self.message_history.append({
|
| 94 |
+
"role": "assistant",
|
| 95 |
+
"content": [{"type": "text", "text": answer}]
|
| 96 |
})
|
| 97 |
|
| 98 |
# -------------------------------
|
|
|
|
| 182 |
audio_text = whisper_pipe(audio_path)["text"]
|
| 183 |
question += " " + audio_text
|
| 184 |
|
| 185 |
+
# Append question to conversation history (only if alternating correctly)
|
| 186 |
session.add_question(question)
|
| 187 |
|
| 188 |
# Generate response using GEMMA with full conversation history
|
| 189 |
gemma_output = gemma_pipe(text=session.message_history, max_new_tokens=200)
|
| 190 |
answer = gemma_output[0]["generated_text"][-1]["content"]
|
| 191 |
|
| 192 |
+
# Append GEMMA's response to the history to maintain alternating structure
|
| 193 |
+
session.add_answer(answer)
|
| 194 |
+
|
| 195 |
# If TTS is enabled, synthesize answer as speech
|
| 196 |
output_audio_path = "response.wav"
|
| 197 |
if enable_tts:
|
|
|
|
| 199 |
else:
|
| 200 |
output_audio_path = None
|
| 201 |
|
| 202 |
+
return answer, output_audio_path
|