jdesiree commited on
Commit
43fdc71
·
verified ·
1 Parent(s): 3fa11a9

Graph Tool Integration

Browse files
Files changed (1) hide show
  1. app.py +88 -43
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import gradio as gr
2
- from langchain.prompts import ChatPromptTemplate
3
- from langchain.schema import HumanMessage, SystemMessage, AIMessage
4
- from huggingface_hub import InferenceClient
5
  from metrics import EduBotMetrics
 
 
6
  import os
7
  import time
8
  import logging
 
9
  import re
10
 
11
  # --- Environment and Logging Setup ---
@@ -18,13 +19,34 @@ if not hf_token:
18
  logger.warning("Neither HF_TOKEN nor HUGGINGFACEHUB_API_TOKEN is set, the application may not work.")
19
 
20
  # --- LLM Configuration ---
21
- client = InferenceClient(
22
- provider="together",
23
- api_key=hf_token,
24
- )
25
 
26
  metrics_tracker = EduBotMetrics(save_file="edu_metrics.json")
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # --- LLM Templates ---
29
  # Enhanced base system message
30
  SYSTEM_MESSAGE = """You are EduBot, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
@@ -148,76 +170,99 @@ def smart_truncate(text, max_length=3000):
148
  return ' '.join(words[:-1]) + "... [Response truncated - ask for continuation]"
149
 
150
  def respond_with_enhanced_streaming(message, history):
151
- """Streams the bot's response, handling errors with metrics tracking."""
152
-
153
- # Start metrics timing
154
  timing_context = metrics_tracker.start_timing()
155
  error_occurred = False
156
  error_message = None
157
- response = ""
158
 
159
  try:
160
- # Build conversation history (last 5 exchanges)
161
  api_messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
162
  if history:
 
163
  for exchange in history[-5:]:
164
- if exchange.get("role") == "user":
165
- api_messages.append({"role": "user", "content": exchange["content"]})
166
- elif exchange.get("role") == "assistant":
167
- api_messages.append({"role": "assistant", "content": exchange["content"]})
168
 
169
- # Add current user message
170
  api_messages.append({"role": "user", "content": message})
171
-
172
- # Mark provider API start
173
  metrics_tracker.mark_provider_start(timing_context)
174
 
175
- completion = client.chat.completions.create(
176
  model="Qwen/Qwen2.5-7B-Instruct",
177
  messages=api_messages,
178
  max_tokens=4096,
179
  temperature=0.7,
180
  top_p=0.9,
 
 
181
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- # Mark provider API end
184
  metrics_tracker.mark_provider_end(timing_context)
 
185
 
186
- response = smart_truncate(completion.choices[0].message.content, max_length=3000)
187
-
188
- # Stream fake-chunks word by word
189
- words = response.split()
190
- partial_response = ""
191
-
192
- for i, word in enumerate(words):
193
- partial_response += word + " "
194
- if i % 4 == 0:
195
- metrics_tracker.record_chunk(timing_context)
196
- yield partial_response
197
- time.sleep(0.03)
198
-
199
- logger.info(f"Response completed. Length: {len(response)} characters")
200
-
201
- metrics_tracker.record_chunk(timing_context)
202
- yield response
203
-
204
  except Exception as e:
205
  error_occurred = True
206
  error_message = str(e)
207
  logger.exception("Error in response generation")
208
- yield "Sorry, I encountered an error while generating the response."
209
 
210
  finally:
211
  metrics_tracker.log_interaction(
212
  query=message,
213
- response=response,
214
  timing_context=timing_context,
215
  error_occurred=error_occurred,
216
  error_message=error_message,
217
  )
218
 
219
-
220
-
221
  # ===============================================================================
222
  # UI CONFIGURATION SECTION - ALL UI RELATED CODE CENTRALIZED HERE
223
  # ===============================================================================
 
1
  import gradio as gr
2
+ from graph_tool import generate_plot
 
 
3
  from metrics import EduBotMetrics
4
+ from together import Together
5
+ from together.types.chat.completion import CompletionChunk
6
  import os
7
  import time
8
  import logging
9
+ import json
10
  import re
11
 
12
  # --- Environment and Logging Setup ---
 
19
  logger.warning("Neither HF_TOKEN nor HUGGINGFACEHUB_API_TOKEN is set, the application may not work.")
20
 
21
  # --- LLM Configuration ---
22
+ client = Together(api_key=hf_token)
 
 
 
23
 
24
  metrics_tracker = EduBotMetrics(save_file="edu_metrics.json")
25
 
26
+ # --- Tools ---
27
+
28
+ tools = [
29
+ {
30
+ "name": "create_graph",
31
+ "description": "Generates a plot (bar, line, or pie) and returns it as an HTML-formatted Base64-encoded image string. The data and labels arguments must be JSON-encoded strings. Use simple, descriptive labels for plots. Use this tool to produce plots for practice questions, data visualization use cases, or any other case where the student may benefit from a graph.",
32
+ "parameters": {
33
+ "type": "object",
34
+ "properties": {
35
+ "data_json": {"type": "string"},
36
+ "labels_json": {"type": "string"},
37
+ "plot_type": {"type": "string", "enum": ["bar", "line", "pie"]},
38
+ "title": {"type": "string"},
39
+ "x_label": {"type": "string"},
40
+ "y_label": {"type": "string"},
41
+ },
42
+ "required": ["data_json", "labels_json", "plot_type", "title"],
43
+ },
44
+ }
45
+ ]
46
+
47
+ # --- BIND THE TOOL TO THE MODEL ---
48
+ model_with_tools = llm.bind_tools([create_graph])
49
+
50
  # --- LLM Templates ---
51
  # Enhanced base system message
52
  SYSTEM_MESSAGE = """You are EduBot, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
 
170
  return ' '.join(words[:-1]) + "... [Response truncated - ask for continuation]"
171
 
172
  def respond_with_enhanced_streaming(message, history):
173
+ """Streams the bot's response, handling tool calls and errors with metrics tracking."""
 
 
174
  timing_context = metrics_tracker.start_timing()
175
  error_occurred = False
176
  error_message = None
177
+ full_response = ""
178
 
179
  try:
 
180
  api_messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
181
  if history:
182
+ # Prepare history for API call
183
  for exchange in history[-5:]:
184
+ api_messages.append({"role": "user", "content": exchange[0]})
185
+ api_messages.append({"role": "assistant", "content": exchange[1]})
 
 
186
 
 
187
  api_messages.append({"role": "user", "content": message})
188
+
 
189
  metrics_tracker.mark_provider_start(timing_context)
190
 
191
+ stream = client.chat.completions.create(
192
  model="Qwen/Qwen2.5-7B-Instruct",
193
  messages=api_messages,
194
  max_tokens=4096,
195
  temperature=0.7,
196
  top_p=0.9,
197
+ stream=True,
198
+ tools=tools, # Pass the tool definitions here
199
  )
200
+
201
+ # Buffers to handle multi-chunk tool calls
202
+ tool_call_name = ""
203
+ tool_call_args_str = ""
204
+
205
+ for chunk in stream:
206
+ if isinstance(chunk, CompletionChunk):
207
+ # Handle text chunks
208
+ if chunk.choices and chunk.choices[0].delta.content:
209
+ text_chunk = chunk.choices[0].delta.content
210
+ full_response += text_chunk
211
+ yield full_response
212
+
213
+ # Handle tool call chunks
214
+ if chunk.choices and chunk.choices[0].delta.tool_calls:
215
+ tool_call_delta = chunk.choices[0].delta.tool_calls[0]
216
+
217
+ # Accumulate name and arguments from stream
218
+ if tool_call_delta.function.name:
219
+ tool_call_name = tool_call_delta.function.name
220
+ if tool_call_delta.function.arguments:
221
+ tool_call_args_str += tool_call_delta.function.arguments
222
+
223
+ # Check if we have received the full tool call
224
+ # This is a simple heuristic and might need refinement
225
+ if tool_call_name and '}' in tool_call_args_str:
226
+ try:
227
+ tool_args = json.loads(tool_call_args_str)
228
+ if tool_call_name == "create_graph":
229
+ logger.info(f"Executing tool: {tool_call_name} with args: {tool_args}")
230
+ graph_html = generate_plot(**tool_args)
231
+ full_response += graph_html
232
+ yield full_response
233
+
234
+ # Reset buffers
235
+ tool_call_name = ""
236
+ tool_call_args_str = ""
237
+
238
+ except json.JSONDecodeError:
239
+ logger.error("JSON parsing failed for tool arguments.")
240
+ # Yield an error or handle it silently
241
+ full_response += f"<p style='color:red;'>Error parsing graph data.</p>"
242
+ yield full_response
243
+ except Exception as e:
244
+ logger.exception("Error executing tool")
245
+ full_response += f"<p style='color:red;'>Error executing tool: {e}</p>"
246
+ yield full_response
247
 
 
248
  metrics_tracker.mark_provider_end(timing_context)
249
+ logger.info(f"Response completed. Length: {len(full_response)} characters")
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  except Exception as e:
252
  error_occurred = True
253
  error_message = str(e)
254
  logger.exception("Error in response generation")
255
+ yield "Sorry, an error occurred while generating the response."
256
 
257
  finally:
258
  metrics_tracker.log_interaction(
259
  query=message,
260
+ response=full_response,
261
  timing_context=timing_context,
262
  error_occurred=error_occurred,
263
  error_message=error_message,
264
  )
265
 
 
 
266
  # ===============================================================================
267
  # UI CONFIGURATION SECTION - ALL UI RELATED CODE CENTRALIZED HERE
268
  # ===============================================================================