jdesiree commited on
Commit
766d00f
·
verified ·
1 Parent(s): 88b796b

Major Update

Browse files

Changed to an Any-to-Any model, Qwen2.5-Omni.
- Removed mock-streaming, as Omni has this feature.
- Made necessary adjustments to implement Omni
- Added toggle button for voice responses
- Removed now unused imports
- Increased token limit to match smart truncation
- Reinforced JSON file format for Graph tool

Files changed (1) hide show
  1. app.py +148 -61
app.py CHANGED
@@ -9,12 +9,16 @@ import re
9
  import requests
10
  from langchain.tools import BaseTool
11
  from langchain.agents import initialize_agent, AgentType
12
- from langchain_community.llms import HuggingFaceHub
13
  from langchain.memory import ConversationBufferWindowMemory
14
- from langchain.prompts import PromptTemplate
15
- from langchain.schema import SystemMessage, HumanMessage, AIMessage
 
16
  from pydantic import BaseModel, Field
17
- from typing import Type, Optional
 
 
 
 
18
 
19
  # --- Environment and Logging Setup ---
20
  logging.basicConfig(level=logging.INFO)
@@ -30,7 +34,7 @@ metrics_tracker = EduBotMetrics(save_file="edu_metrics.json")
30
  # --- LangChain Tool Definition ---
31
  class GraphInput(BaseModel):
32
  data_json: str = Field(description="JSON string of data for the graph")
33
- labels_json: str = Field(description="JSON string of labels for the graph")
34
  plot_type: str = Field(description="Type of plot: bar, line, or pie")
35
  title: str = Field(description="Title for the graph")
36
  x_label: str = Field(description="X-axis label", default="")
@@ -38,12 +42,42 @@ class GraphInput(BaseModel):
38
 
39
  class CreateGraphTool(BaseTool):
40
  name: str = "create_graph"
41
- description: str = """Generates a plot (bar, line, or pie) and returns it as an HTML-formatted Base64-encoded image string. Use this tool when teaching concepts that benefit from visual representation, such as: statistical distributions, mathematical functions, data comparisons, survey results, grade analyses, scientific relationships, economic models, or any quantitative information that would be clearer with a graph. The data and labels arguments must be JSON-encoded strings."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  args_schema: Type[BaseModel] = GraphInput
43
 
44
- def _run(self, data_json: str, labels_json: str, plot_type: str,
45
- title: str, x_label: str = "", y_label: str = "") -> str:
46
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  return generate_plot(
48
  data_json=data_json,
49
  labels_json=labels_json,
@@ -55,6 +89,7 @@ class CreateGraphTool(BaseTool):
55
  except Exception as e:
56
  return f"<p style='color:red;'>Error creating graph: {str(e)}</p>"
57
 
 
58
  # --- System Prompt ---
59
  SYSTEM_PROMPT = """You are EduBot, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
60
 
@@ -126,32 +161,51 @@ def initialize_system_prompt(agent):
126
  agent.memory.chat_memory.add_message(system_message)
127
  system_prompt_initialized = True
128
 
129
- def create_langchain_agent():
130
- """Initialize LangChain agent with tools and memory."""
 
131
 
132
- # Initialize LLM
133
- llm = HuggingFaceHub(
134
- repo_id="Qwen/Qwen2.5-VL-7B-Instruct",
135
- huggingfacehub_api_token=hf_token,
136
- model_kwargs={
137
- "temperature": 0.7,
138
- "max_new_tokens": 1000,
139
- "top_p": 0.9,
140
- "return_full_text": False
141
- }
142
- )
143
 
144
- # Initialize tools
145
- tools = [CreateGraphTool()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- # Initialize memory
 
 
 
 
 
148
  memory = ConversationBufferWindowMemory(
149
  memory_key="chat_history",
150
  k=10,
151
  return_messages=True
152
  )
153
 
154
- # Create agent WITHOUT system prompt in prefix (we'll add it to memory instead)
155
  agent = initialize_agent(
156
  tools=tools,
157
  llm=llm,
@@ -174,6 +228,52 @@ def get_agent():
174
  agent = create_langchain_agent()
175
  return agent
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  # --- UI: MathJax Configuration ---
178
  mathjax_config = '''
179
  <script>
@@ -271,7 +371,6 @@ def chat_response(message, history=None):
271
  logger.info(f"Message type: {type(message)}")
272
  logger.info(f"Message content: {message}")
273
 
274
- # This line might be causing the issue
275
  try:
276
  metrics_tracker.log_interaction(message, "user_query", "chat_start")
277
  logger.info("Metrics interaction logged successfully")
@@ -308,36 +407,21 @@ def chat_response(message, history=None):
308
  logger.error(f"Full traceback: {traceback.format_exc()}")
309
  return f"I apologize, but I encountered an error while processing your message: {str(e)}"
310
 
311
- def respond_with_enhanced_streaming(message, history=None):
312
- """Enhanced streaming response function."""
313
- try:
314
- response = chat_response(message)
315
- yield response
316
- except Exception as e:
317
- logger.error(f"Error in streaming response: {e}")
318
- yield f"I apologize, but I encountered an error: {str(e)}"
319
-
320
- # --- UI: Event Handlers ---
321
- def respond_and_update(message, history):
322
  """Main function to handle user submission."""
323
  if not message.strip():
324
- return history, ""
325
 
326
  # Add user message to history
327
  history.append({"role": "user", "content": message})
328
- # Yield history to show the user message immediately, and clear the textbox
329
- yield history, ""
330
-
331
- # Stream the bot's response
332
- full_response = ""
333
- for response_chunk in respond_with_enhanced_streaming(message, history):
334
- full_response = response_chunk
335
- # Update the last message (bot's response)
336
- if len(history) > 0 and history[-1]["role"] == "user":
337
- history.append({"role": "assistant", "content": full_response})
338
- else:
339
- history[-1] = {"role": "assistant", "content": full_response}
340
- yield history, ""
341
 
342
  def clear_chat():
343
  """Clear the chat history and reset system prompt flag."""
@@ -347,7 +431,6 @@ def clear_chat():
347
  system_prompt_initialized = False
348
  return [], ""
349
 
350
-
351
  # --- UI: Interface Creation ---
352
  def create_interface():
353
  """Creates and configures the complete Gradio interface."""
@@ -358,9 +441,9 @@ def create_interface():
358
  with open("styles.css", "r", encoding="utf-8") as css_file:
359
  custom_css = css_file.read()
360
  except FileNotFoundError:
361
- logger.warning("style.css file not found, using default styling")
362
  except Exception as e:
363
- logger.warning(f"Error reading style.css: {e}")
364
 
365
  with gr.Blocks(
366
  title="EduBot",
@@ -405,14 +488,18 @@ def create_interface():
405
  with gr.Column(elem_classes=["button-column"], scale=1):
406
  send = gr.Button("Send", elem_classes=["send-button"], size="sm")
407
  clear = gr.Button("Clear", elem_classes=["clear-button"], size="sm")
408
-
409
- # Set up event handlers
410
- msg.submit(respond_and_update, [msg, chatbot], [chatbot, msg])
411
- send.click(respond_and_update, [msg, chatbot], [chatbot, msg])
412
- clear.click(clear_chat, outputs=[chatbot, msg])
413
 
414
- # Apply CSS at the very end for highest precedence
415
- gr.HTML(f'<style>{custom_css}</style>')
 
 
 
 
 
 
 
 
416
 
417
  return demo
418
 
 
9
  import requests
10
  from langchain.tools import BaseTool
11
  from langchain.agents import initialize_agent, AgentType
 
12
  from langchain.memory import ConversationBufferWindowMemory
13
+ from langchain.schema import SystemMessage
14
+ from langchain.llms.base import LLM
15
+ from typing import Optional, List, Any, Type
16
  from pydantic import BaseModel, Field
17
+ from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
18
+ from qwen_omni_utils import process_mm_info
19
+ import soundfile as sf
20
+ import atexit
21
+ import glob
22
 
23
  # --- Environment and Logging Setup ---
24
  logging.basicConfig(level=logging.INFO)
 
34
  # --- LangChain Tool Definition ---
35
  class GraphInput(BaseModel):
36
  data_json: str = Field(description="JSON string of data for the graph")
37
+ labels_json: str = Field(description="JSON string of labels for the graph", default="[]")
38
  plot_type: str = Field(description="Type of plot: bar, line, or pie")
39
  title: str = Field(description="Title for the graph")
40
  x_label: str = Field(description="X-axis label", default="")
 
42
 
43
  class CreateGraphTool(BaseTool):
44
  name: str = "create_graph"
45
+ description: str = """Generates a plot (bar, line, or pie) and returns it as an HTML-formatted Base64-encoded image string. Use this tool when teaching concepts that benefit from visual representation, such as: statistical distributions, mathematical functions, data comparisons, survey results, grade analyses, scientific relationships, economic models, or any quantitative information that would be clearer with a graph.
46
+
47
+ REQUIRED FORMAT:
48
+ - data_json: A JSON dictionary where keys are category names and values are numbers
49
+ Example: '{"Math": 85, "Science": 92, "English": 78}'
50
+ - labels_json: A JSON list, only needed for pie charts if you want custom labels different from the data keys. For bar/line charts, use empty list: '[]'
51
+ Example for pie: '["Mathematics", "Science", "English Literature"]'
52
+ Example for bar/line: '[]'
53
+
54
+ EXAMPLES:
55
+ Bar chart: data_json='{"Q1": 1000, "Q2": 1200, "Q3": 950}', labels_json='[]'
56
+ Line chart: data_json='{"Jan": 100, "Feb": 120, "Mar": 110}', labels_json='[]'
57
+ Pie chart: data_json='{"A": 30, "B": 45, "C": 25}', labels_json='["Category A", "Category B", "Category C"]'
58
+
59
+ Always use proper JSON formatting with quotes around keys and string values."""
60
  args_schema: Type[BaseModel] = GraphInput
61
 
62
+ def _run(self, data_json: str, labels_json: str = "[]", plot_type: str = "bar",
63
+ title: str = "Chart", x_label: str = "", y_label: str = "") -> str:
64
  try:
65
+ # Validate JSON format before passing to generate_plot
66
+ import json
67
+ try:
68
+ data_parsed = json.loads(data_json)
69
+ labels_parsed = json.loads(labels_json)
70
+
71
+ # Validate data structure
72
+ if not isinstance(data_parsed, dict):
73
+ return "<p style='color:red;'>data_json must be a JSON dictionary with string keys and numeric values.</p>"
74
+
75
+ if not isinstance(labels_parsed, list):
76
+ return "<p style='color:red;'>labels_json must be a JSON list (use [] if no custom labels needed).</p>"
77
+
78
+ except json.JSONDecodeError as json_error:
79
+ return f"<p style='color:red;'>Invalid JSON format: {str(json_error)}. Ensure proper JSON formatting with quotes.</p>"
80
+
81
  return generate_plot(
82
  data_json=data_json,
83
  labels_json=labels_json,
 
89
  except Exception as e:
90
  return f"<p style='color:red;'>Error creating graph: {str(e)}</p>"
91
 
92
+
93
  # --- System Prompt ---
94
  SYSTEM_PROMPT = """You are EduBot, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
95
 
 
161
  agent.memory.chat_memory.add_message(system_message)
162
  system_prompt_initialized = True
163
 
164
class Qwen25OmniLLM(LLM):
    """LangChain LLM wrapper around the Qwen2.5-Omni multimodal model.

    Loads the model and processor once at construction time and exposes a
    text-only ``_call`` interface so the model can be used with LangChain
    agents.
    """

    # Declared as ``Any`` so pydantic (backing LangChain's LLM base class)
    # does not attempt to validate the transformers objects.
    model: Any = None
    processor: Any = None

    def __init__(self, model_path: str = "Qwen/Qwen2.5-Omni-7B"):
        """Load the model weights and processor from *model_path*.

        ``torch_dtype="auto"`` and ``device_map="auto"`` let transformers
        choose precision and device placement for the available hardware.
        """
        super().__init__()
        self.model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto"
        )
        self.processor = Qwen2_5OmniProcessor.from_pretrained(model_path)

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a text-only completion for *prompt*.

        NOTE(review): ``stop`` is accepted for LangChain interface
        compatibility but is not applied to generation.
        """
        # Text-only conversation. The global SYSTEM_PROMPT is injected on
        # every call here, in addition to any system message the agent
        # keeps in its memory.
        conversation = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        ]

        text = self.processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        # No audio/images/videos are present in this conversation, but
        # process_mm_info produces the (empty) multimodal inputs in the
        # layout the processor expects.
        audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
        inputs = self.processor(text=text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True)
        inputs = inputs.to(self.model.device)

        # return_audio=False: request text tokens only, no speech output.
        text_ids = self.model.generate(**inputs, return_audio=False)
        # NOTE(review): decoding the raw generated ids may include the
        # prompt text as well as the completion — confirm against actual
        # model output and trim if needed.
        response = self.processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        return response

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses to label this LLM implementation."""
        return "qwen25_omni"
196
 
197
+ def create_langchain_agent():
198
+ # Replace HuggingFaceHub with custom LLM
199
+ llm = Qwen25OmniLLM()
200
+
201
+ # Rest remains the same
202
+ tools = [CreateGraphTool()]
203
  memory = ConversationBufferWindowMemory(
204
  memory_key="chat_history",
205
  k=10,
206
  return_messages=True
207
  )
208
 
 
209
  agent = initialize_agent(
210
  tools=tools,
211
  llm=llm,
 
228
  agent = create_langchain_agent()
229
  return agent
230
 
231
def generate_voice_response(text_response: str, voice_enabled: bool = False) -> Optional[str]:
    """Generate audio response if voice is enabled.

    Synthesizes *text_response* as speech with the Qwen2.5-Omni model and
    writes it to a temporary WAV file.

    Args:
        text_response: The assistant's text reply to read aloud.
        voice_enabled: When False, skip synthesis entirely.

    Returns:
        Path to the written WAV file, or None when voice is disabled,
        the model appears not to support audio, or synthesis fails.
    """
    if not voice_enabled:
        return None

    try:
        # Reuse the agent's already-loaded model/processor rather than
        # loading a second copy of the weights.
        current_agent = get_agent()
        model = current_agent.llm.model
        processor = current_agent.llm.processor

        # NOTE(review): this only checks that ``generate`` exists and is a
        # Python-level callable with a __code__ attribute; it does not
        # actually verify audio-generation capability.
        if not hasattr(model, 'generate') or not hasattr(model.generate, '__code__'):
            logger.warning("Model may not support audio generation")
            return None

        conversation = [
            {"role": "system", "content": [{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}]},
            {"role": "user", "content": [{"type": "text", "text": "Please read this response aloud: " + text_response}]}
        ]

        text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
        audios, images, videos = process_mm_info(conversation, use_audio_in_video=False)
        inputs = processor(text=text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True)
        inputs = inputs.to(model.device)

        # With speech output enabled, generate returns both the text token
        # ids and the audio waveform; only the audio is used here.
        text_ids, audio = model.generate(**inputs, speaker="Ethan")

        # Save audio to temporary file. The name pattern must stay in sync
        # with the temp_audio_*.wav glob used by cleanup_temp_audio().
        audio_path = f"temp_audio_{int(time.time())}.wav"
        # 24000 Hz is presumably the model's output sample rate — confirm
        # against the Qwen2.5-Omni documentation.
        sf.write(audio_path, audio.reshape(-1).detach().cpu().numpy(), samplerate=24000)
        return audio_path

    except Exception as e:
        logger.error(f"Error generating voice response: {e}")
        return None
265
+
266
def cleanup_temp_audio():
    """Remove temporary voice-response WAV files from the working directory.

    Deletes every file matching ``temp_audio_*.wav`` (the naming pattern
    used by generate_voice_response). Registered with atexit so leftovers
    are purged when the interpreter exits. Cleanup is best-effort: files
    that vanish or cannot be removed are skipped silently.
    """
    for file in glob.glob("temp_audio_*.wav"):
        try:
            os.remove(file)
        # Catch only filesystem errors — a bare ``except`` would also
        # swallow KeyboardInterrupt/SystemExit during shutdown.
        except OSError:
            pass

# Register cleanup function
atexit.register(cleanup_temp_audio)
276
+
277
  # --- UI: MathJax Configuration ---
278
  mathjax_config = '''
279
  <script>
 
371
  logger.info(f"Message type: {type(message)}")
372
  logger.info(f"Message content: {message}")
373
 
 
374
  try:
375
  metrics_tracker.log_interaction(message, "user_query", "chat_start")
376
  logger.info("Metrics interaction logged successfully")
 
407
  logger.error(f"Full traceback: {traceback.format_exc()}")
408
  return f"I apologize, but I encountered an error while processing your message: {str(e)}"
409
 
410
def respond_and_update(message, history, voice_enabled):
    """Main function to handle user submission.

    Generator used as a Gradio event handler. Each yield is a tuple of
    (chat history, textbox value, audio file path or None).

    Args:
        message: Raw text the user submitted.
        history: Chat history as a list of {"role", "content"} dicts;
            mutated in place.
        voice_enabled: When True, also synthesize a spoken response.
    """
    # BUG FIX: this function is a generator, so ``return history, "", None``
    # is silently discarded (it only sets StopIteration.value and Gradio
    # never receives it). Yield the unchanged state instead, then stop.
    if not message.strip():
        yield history, "", None
        return

    # Add user message to history and surface it immediately while the
    # model works; the textbox is cleared at the same time.
    history.append({"role": "user", "content": message})
    yield history, "", None

    # Generate response directly (no mock streaming)
    response = chat_response(message)
    audio_path = generate_voice_response(response, voice_enabled) if voice_enabled else None

    history.append({"role": "assistant", "content": response})
    yield history, "", audio_path
 
 
 
 
 
425
 
426
  def clear_chat():
427
  """Clear the chat history and reset system prompt flag."""
 
431
  system_prompt_initialized = False
432
  return [], ""
433
 
 
434
  # --- UI: Interface Creation ---
435
  def create_interface():
436
  """Creates and configures the complete Gradio interface."""
 
441
  with open("styles.css", "r", encoding="utf-8") as css_file:
442
  custom_css = css_file.read()
443
  except FileNotFoundError:
444
+ logger.warning("styles.css file not found, using default styling")
445
  except Exception as e:
446
+ logger.warning(f"Error reading styles.css: {e}")
447
 
448
  with gr.Blocks(
449
  title="EduBot",
 
488
  with gr.Column(elem_classes=["button-column"], scale=1):
489
  send = gr.Button("Send", elem_classes=["send-button"], size="sm")
490
  clear = gr.Button("Clear", elem_classes=["clear-button"], size="sm")
491
+ voice_toggle = gr.Checkbox(label="Enable Voice (Ethan)", value=False, elem_classes=["voice-toggle"])
 
 
 
 
492
 
493
+ # Add audio output component
494
+ audio_output = gr.Audio(label="Voice Response", visible=True, autoplay=True)
495
+
496
+ # Event handlers - INSIDE the Blocks context
497
+ msg.submit(respond_and_update, [msg, chatbot, voice_toggle], [chatbot, msg, audio_output])
498
+ send.click(respond_and_update, [msg, chatbot, voice_toggle], [chatbot, msg, audio_output])
499
+ clear.click(clear_chat, outputs=[chatbot, msg])
500
+
501
+ # Apply CSS at the very end
502
+ gr.HTML(f'<style>{custom_css}</style>')
503
 
504
  return demo
505