Spaces:
Sleeping
Sleeping
Graph Tool Integration
Browse files
app.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from
|
| 3 |
-
from langchain.schema import HumanMessage, SystemMessage, AIMessage
|
| 4 |
-
from huggingface_hub import InferenceClient
|
| 5 |
from metrics import EduBotMetrics
|
|
|
|
|
|
|
| 6 |
import os
|
| 7 |
import time
|
| 8 |
import logging
|
|
|
|
| 9 |
import re
|
| 10 |
|
| 11 |
# --- Environment and Logging Setup ---
|
|
@@ -18,13 +19,34 @@ if not hf_token:
|
|
| 18 |
logger.warning("Neither HF_TOKEN nor HUGGINGFACEHUB_API_TOKEN is set, the application may not work.")
|
| 19 |
|
| 20 |
# --- LLM Configuration ---
|
| 21 |
-
client =
|
| 22 |
-
provider="together",
|
| 23 |
-
api_key=hf_token,
|
| 24 |
-
)
|
| 25 |
|
| 26 |
metrics_tracker = EduBotMetrics(save_file="edu_metrics.json")
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# --- LLM Templates ---
|
| 29 |
# Enhanced base system message
|
| 30 |
SYSTEM_MESSAGE = """You are EduBot, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
|
|
@@ -148,76 +170,99 @@ def smart_truncate(text, max_length=3000):
|
|
| 148 |
return ' '.join(words[:-1]) + "... [Response truncated - ask for continuation]"
|
| 149 |
|
| 150 |
def respond_with_enhanced_streaming(message, history):
|
| 151 |
-
"""Streams the bot's response, handling errors with metrics tracking."""
|
| 152 |
-
|
| 153 |
-
# Start metrics timing
|
| 154 |
timing_context = metrics_tracker.start_timing()
|
| 155 |
error_occurred = False
|
| 156 |
error_message = None
|
| 157 |
-
|
| 158 |
|
| 159 |
try:
|
| 160 |
-
# Build conversation history (last 5 exchanges)
|
| 161 |
api_messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
|
| 162 |
if history:
|
|
|
|
| 163 |
for exchange in history[-5:]:
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
elif exchange.get("role") == "assistant":
|
| 167 |
-
api_messages.append({"role": "assistant", "content": exchange["content"]})
|
| 168 |
|
| 169 |
-
# Add current user message
|
| 170 |
api_messages.append({"role": "user", "content": message})
|
| 171 |
-
|
| 172 |
-
# Mark provider API start
|
| 173 |
metrics_tracker.mark_provider_start(timing_context)
|
| 174 |
|
| 175 |
-
|
| 176 |
model="Qwen/Qwen2.5-7B-Instruct",
|
| 177 |
messages=api_messages,
|
| 178 |
max_tokens=4096,
|
| 179 |
temperature=0.7,
|
| 180 |
top_p=0.9,
|
|
|
|
|
|
|
| 181 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
-
# Mark provider API end
|
| 184 |
metrics_tracker.mark_provider_end(timing_context)
|
|
|
|
| 185 |
|
| 186 |
-
response = smart_truncate(completion.choices[0].message.content, max_length=3000)
|
| 187 |
-
|
| 188 |
-
# Stream fake-chunks word by word
|
| 189 |
-
words = response.split()
|
| 190 |
-
partial_response = ""
|
| 191 |
-
|
| 192 |
-
for i, word in enumerate(words):
|
| 193 |
-
partial_response += word + " "
|
| 194 |
-
if i % 4 == 0:
|
| 195 |
-
metrics_tracker.record_chunk(timing_context)
|
| 196 |
-
yield partial_response
|
| 197 |
-
time.sleep(0.03)
|
| 198 |
-
|
| 199 |
-
logger.info(f"Response completed. Length: {len(response)} characters")
|
| 200 |
-
|
| 201 |
-
metrics_tracker.record_chunk(timing_context)
|
| 202 |
-
yield response
|
| 203 |
-
|
| 204 |
except Exception as e:
|
| 205 |
error_occurred = True
|
| 206 |
error_message = str(e)
|
| 207 |
logger.exception("Error in response generation")
|
| 208 |
-
yield "Sorry,
|
| 209 |
|
| 210 |
finally:
|
| 211 |
metrics_tracker.log_interaction(
|
| 212 |
query=message,
|
| 213 |
-
response=
|
| 214 |
timing_context=timing_context,
|
| 215 |
error_occurred=error_occurred,
|
| 216 |
error_message=error_message,
|
| 217 |
)
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
# ===============================================================================
|
| 222 |
# UI CONFIGURATION SECTION - ALL UI RELATED CODE CENTRALIZED HERE
|
| 223 |
# ===============================================================================
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from graph_tool import generate_plot
|
|
|
|
|
|
|
| 3 |
from metrics import EduBotMetrics
|
| 4 |
+
from together import Together
|
| 5 |
+
from together.types.chat.completion import CompletionChunk
|
| 6 |
import os
|
| 7 |
import time
|
| 8 |
import logging
|
| 9 |
+
import json
|
| 10 |
import re
|
| 11 |
|
| 12 |
# --- Environment and Logging Setup ---
|
|
|
|
| 19 |
logger.warning("Neither HF_TOKEN nor HUGGINGFACEHUB_API_TOKEN is set, the application may not work.")
|
| 20 |
|
| 21 |
# --- LLM Configuration ---
|
| 22 |
+
client = Together(api_key=hf_token)
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
metrics_tracker = EduBotMetrics(save_file="edu_metrics.json")
|
| 25 |
|
| 26 |
+
# --- Tools ---

# Tool schema for the Together chat-completions API (OpenAI-compatible).
# Each tool must be wrapped as {"type": "function", "function": {...}};
# a bare name/description/parameters dict is rejected by the API.
tools = [
    {
        "type": "function",
        "function": {
            "name": "create_graph",
            "description": "Generates a plot (bar, line, or pie) and returns it as an HTML-formatted Base64-encoded image string. The data and labels arguments must be JSON-encoded strings. Use simple, descriptive labels for plots. Use this tool to produce plots for practice questions, data visualization use cases, or any other case where the student may benefit from a graph.",
            "parameters": {
                "type": "object",
                "properties": {
                    "data_json": {"type": "string"},
                    "labels_json": {"type": "string"},
                    "plot_type": {"type": "string", "enum": ["bar", "line", "pie"]},
                    "title": {"type": "string"},
                    "x_label": {"type": "string"},
                    "y_label": {"type": "string"},
                },
                "required": ["data_json", "labels_json", "plot_type", "title"],
            },
        },
    }
]

# NOTE(review): removed `model_with_tools = llm.bind_tools([create_graph])` —
# neither `llm` nor `create_graph` is defined in this file (leftover LangChain
# code from before the Together-client migration); it raised NameError at
# import. The Together client receives the tool schema directly via the
# `tools=` argument of `client.chat.completions.create(...)`.
|
| 49 |
+
|
| 50 |
# --- LLM Templates ---
|
| 51 |
# Enhanced base system message
|
| 52 |
SYSTEM_MESSAGE = """You are EduBot, an expert multi-concept tutor designed to facilitate genuine learning and understanding. Your primary mission is to guide students through the learning process rather than providing direct answers to academic work.
|
|
|
|
| 170 |
return ' '.join(words[:-1]) + "... [Response truncated - ask for continuation]"
|
| 171 |
|
| 172 |
def respond_with_enhanced_streaming(message, history):
    """Stream the bot's reply, executing `create_graph` tool calls inline.

    Args:
        message: The current user message (str).
        history: Prior chat turns. Supports both Gradio history formats —
            a list of (user_text, assistant_text) tuples, or a list of
            {"role": ..., "content": ...} dicts.

    Yields:
        The progressively accumulated response text (the full text so far
        is re-yielded on each chunk, as Gradio's streaming chat expects).
    """
    timing_context = metrics_tracker.start_timing()
    error_occurred = False
    error_message = None
    full_response = ""

    try:
        api_messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
        if history:
            # Keep only the last 5 exchanges to bound prompt size.
            for exchange in history[-5:]:
                if isinstance(exchange, dict):
                    # Gradio "messages" format: one role/content dict per turn.
                    role = exchange.get("role")
                    if role in ("user", "assistant"):
                        api_messages.append(
                            {"role": role, "content": exchange["content"]}
                        )
                else:
                    # Legacy tuple format: (user_text, assistant_text).
                    api_messages.append({"role": "user", "content": exchange[0]})
                    api_messages.append({"role": "assistant", "content": exchange[1]})

        api_messages.append({"role": "user", "content": message})

        metrics_tracker.mark_provider_start(timing_context)

        stream = client.chat.completions.create(
            model="Qwen/Qwen2.5-7B-Instruct",
            messages=api_messages,
            max_tokens=4096,
            temperature=0.7,
            top_p=0.9,
            stream=True,
            tools=tools,  # Pass the tool definitions here
        )

        # Buffers to accumulate a tool call that arrives across chunks:
        # the name typically arrives once, the JSON arguments in pieces.
        tool_call_name = ""
        tool_call_args_str = ""

        for chunk in stream:
            if not isinstance(chunk, CompletionChunk) or not chunk.choices:
                continue
            delta = chunk.choices[0].delta

            # Plain text chunk: append and re-yield the whole response.
            if delta.content:
                full_response += delta.content
                metrics_tracker.record_chunk(timing_context)
                yield full_response

            # Tool-call chunk: accumulate name and argument fragments.
            if delta.tool_calls:
                tool_call_delta = delta.tool_calls[0]
                if tool_call_delta.function.name:
                    tool_call_name = tool_call_delta.function.name
                if tool_call_delta.function.arguments:
                    tool_call_args_str += tool_call_delta.function.arguments

                # Attempt execution once the accumulated arguments look
                # complete. A '}' can occur inside a nested or still-partial
                # payload, so a parse failure means "keep accumulating",
                # not "error out".
                if tool_call_name and '}' in tool_call_args_str:
                    try:
                        tool_args = json.loads(tool_call_args_str)
                    except json.JSONDecodeError:
                        continue  # Arguments still incomplete; wait for more.
                    try:
                        if tool_call_name == "create_graph":
                            logger.info(f"Executing tool: {tool_call_name} with args: {tool_args}")
                            graph_html = generate_plot(**tool_args)
                            full_response += graph_html
                            yield full_response
                    except Exception as e:
                        logger.exception("Error executing tool")
                        full_response += f"<p style='color:red;'>Error executing tool: {e}</p>"
                        yield full_response
                    finally:
                        # Always reset buffers so a failed call cannot
                        # poison a subsequent tool call in the same stream.
                        tool_call_name = ""
                        tool_call_args_str = ""

        metrics_tracker.mark_provider_end(timing_context)
        logger.info(f"Response completed. Length: {len(full_response)} characters")

    except Exception as e:
        error_occurred = True
        error_message = str(e)
        logger.exception("Error in response generation")
        yield "Sorry, an error occurred while generating the response."

    finally:
        metrics_tracker.log_interaction(
            query=message,
            response=full_response,
            timing_context=timing_context,
            error_occurred=error_occurred,
            error_message=error_message,
        )
|
| 265 |
|
|
|
|
|
|
|
| 266 |
# ===============================================================================
|
| 267 |
# UI CONFIGURATION SECTION - ALL UI RELATED CODE CENTRALIZED HERE
|
| 268 |
# ===============================================================================
|