Spaces:
Sleeping
Sleeping
Update backend
Browse files
main.py
CHANGED
|
@@ -12,12 +12,16 @@ import json
|
|
| 12 |
import os
|
| 13 |
import time
|
| 14 |
import torch
|
|
|
|
| 15 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
|
| 16 |
|
| 17 |
# Global model and tokenizer
|
| 18 |
model = None
|
| 19 |
tokenizer = None
|
| 20 |
device = None
|
|
|
|
|
|
|
| 21 |
|
| 22 |
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
| 23 |
@asynccontextmanager
|
|
@@ -59,6 +63,10 @@ async def lifespan(app: FastAPI):
|
|
| 59 |
model.eval()
|
| 60 |
print("Model loaded successfully!")
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
yield # Application runs here
|
| 63 |
|
| 64 |
# Shutdown (if needed)
|
|
@@ -188,92 +196,179 @@ async def classify(request: ClassificationRequest):
|
|
| 188 |
raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
|
| 189 |
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
@app.get("/api/examples")
|
| 192 |
async def get_examples():
|
| 193 |
-
"""Return example inputs
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
"
|
| 218 |
-
"
|
| 219 |
}
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
"
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
"
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
}
|
| 237 |
}
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
}
|
| 245 |
-
}
|
| 246 |
-
}
|
| 247 |
-
},
|
| 248 |
-
{
|
| 249 |
-
"name": "Incorrect Argument Type",
|
| 250 |
-
"description": "Argument has wrong data type",
|
| 251 |
-
"data": {
|
| 252 |
-
"query": "Set a reminder for 3pm",
|
| 253 |
-
"enabled_tools": [
|
| 254 |
-
{
|
| 255 |
-
"name": "set_reminder",
|
| 256 |
-
"description": "Create a reminder",
|
| 257 |
-
"parameters": {
|
| 258 |
-
"type": "object",
|
| 259 |
-
"properties": {
|
| 260 |
-
"time": {"type": "string"},
|
| 261 |
-
"message": {"type": "string"}
|
| 262 |
-
}
|
| 263 |
}
|
| 264 |
}
|
| 265 |
-
],
|
| 266 |
-
"tool_calling": {
|
| 267 |
-
"name": "set_reminder",
|
| 268 |
-
"arguments": {
|
| 269 |
-
"time": 1500, # Should be string!
|
| 270 |
-
"message": "Meeting"
|
| 271 |
-
}
|
| 272 |
}
|
| 273 |
}
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
|
|
|
|
|
|
| 277 |
return {"examples": examples}
|
| 278 |
|
| 279 |
|
|
|
|
| 12 |
import os
|
| 13 |
import time
|
| 14 |
import torch
|
| 15 |
+
import random
|
| 16 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 17 |
+
from collections import defaultdict
|
| 18 |
|
| 19 |
# Global model and tokenizer
|
| 20 |
model = None
|
| 21 |
tokenizer = None
|
| 22 |
device = None
|
| 23 |
+
dataset_by_label = None
|
| 24 |
+
dataset_path = None
|
| 25 |
|
| 26 |
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
|
| 27 |
@asynccontextmanager
|
|
|
|
| 63 |
model.eval()
|
| 64 |
print("Model loaded successfully!")
|
| 65 |
|
| 66 |
+
# Load dataset for examples
|
| 67 |
+
print("Loading dataset for examples...")
|
| 68 |
+
load_dataset()
|
| 69 |
+
|
| 70 |
yield # Application runs here
|
| 71 |
|
| 72 |
# Shutdown (if needed)
|
|
|
|
| 196 |
raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
|
| 197 |
|
| 198 |
|
| 199 |
+
def load_dataset():
|
| 200 |
+
"""Load dataset and group examples by label"""
|
| 201 |
+
global dataset_by_label, dataset_path
|
| 202 |
+
|
| 203 |
+
# Try to find dataset file - check multiple possible locations
|
| 204 |
+
possible_paths = [
|
| 205 |
+
os.path.join(os.path.dirname(__file__), "../../dataset/xlam-function-calling-60k/xlam_function_calling_60k_processed_with_ground_truth.json"),
|
| 206 |
+
"/work/cssema416/202610/12/dataset/xlam-function-calling-60k/xlam_function_calling_60k_processed_with_ground_truth.json",
|
| 207 |
+
os.getenv("DATASET_PATH", ""),
|
| 208 |
+
]
|
| 209 |
+
|
| 210 |
+
dataset_path = None
|
| 211 |
+
for path in possible_paths:
|
| 212 |
+
if path and os.path.exists(path):
|
| 213 |
+
dataset_path = path
|
| 214 |
+
break
|
| 215 |
+
|
| 216 |
+
if not dataset_path:
|
| 217 |
+
print("Warning: Dataset file not found. Using hardcoded examples.")
|
| 218 |
+
return None
|
| 219 |
+
|
| 220 |
+
try:
|
| 221 |
+
print(f"Loading dataset from: {dataset_path}")
|
| 222 |
+
with open(dataset_path, 'r') as f:
|
| 223 |
+
data = json.load(f)
|
| 224 |
+
|
| 225 |
+
# Group examples by ground_truth label
|
| 226 |
+
dataset_by_label = defaultdict(list)
|
| 227 |
+
for item in data:
|
| 228 |
+
label = item.get('ground_truth', 'Unknown')
|
| 229 |
+
if label in LABEL_MAP.values():
|
| 230 |
+
dataset_by_label[label].append(item)
|
| 231 |
+
|
| 232 |
+
print(f"Loaded {len(data)} examples. Examples per label: {dict((k, len(v)) for k, v in dataset_by_label.items())}")
|
| 233 |
+
return dataset_by_label
|
| 234 |
+
except Exception as e:
|
| 235 |
+
print(f"Error loading dataset: {e}")
|
| 236 |
+
return None
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def convert_dataset_example_to_api_format(item: Dict[str, Any]) -> Dict[str, Any]:
|
| 240 |
+
"""Convert dataset example to API format"""
|
| 241 |
+
# Convert tools format
|
| 242 |
+
enabled_tools = []
|
| 243 |
+
for tool in item.get('tools', []):
|
| 244 |
+
# Convert parameters from dict format to JSON Schema format
|
| 245 |
+
properties = {}
|
| 246 |
+
required = []
|
| 247 |
+
|
| 248 |
+
tool_params = tool.get('parameters', {})
|
| 249 |
+
if isinstance(tool_params, dict):
|
| 250 |
+
for param_name, param_info in tool_params.items():
|
| 251 |
+
if isinstance(param_info, dict):
|
| 252 |
+
param_type = param_info.get('type', 'string')
|
| 253 |
+
# Map Python types to JSON types
|
| 254 |
+
type_mapping = {
|
| 255 |
+
'str': 'string',
|
| 256 |
+
'int': 'integer',
|
| 257 |
+
'float': 'number',
|
| 258 |
+
'bool': 'boolean',
|
| 259 |
+
'list': 'array',
|
| 260 |
+
'dict': 'object'
|
| 261 |
+
}
|
| 262 |
+
json_type = type_mapping.get(param_type, 'string')
|
| 263 |
+
|
| 264 |
+
prop = {"type": json_type}
|
| 265 |
+
if 'description' in param_info:
|
| 266 |
+
prop['description'] = param_info['description']
|
| 267 |
+
if 'enum' in param_info:
|
| 268 |
+
prop['enum'] = param_info['enum']
|
| 269 |
+
if 'default' not in param_info: # If no default, might be required
|
| 270 |
+
required.append(param_name)
|
| 271 |
+
|
| 272 |
+
properties[param_name] = prop
|
| 273 |
+
|
| 274 |
+
tool_schema = {
|
| 275 |
+
"name": tool.get('name', ''),
|
| 276 |
+
"description": tool.get('description', ''),
|
| 277 |
+
"parameters": {
|
| 278 |
+
"type": "object",
|
| 279 |
+
"properties": properties
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
if required:
|
| 283 |
+
tool_schema["parameters"]["required"] = required
|
| 284 |
+
|
| 285 |
+
enabled_tools.append(tool_schema)
|
| 286 |
+
|
| 287 |
+
# Get tool calling from answers
|
| 288 |
+
tool_calling = None
|
| 289 |
+
if item.get('answers') and len(item['answers']) > 0:
|
| 290 |
+
answer = item['answers'][0]
|
| 291 |
+
tool_calling = {
|
| 292 |
+
"name": answer.get('name', ''),
|
| 293 |
+
"arguments": answer.get('arguments', {})
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
return {
|
| 297 |
+
"query": item.get('query', ''),
|
| 298 |
+
"enabled_tools": enabled_tools,
|
| 299 |
+
"tool_calling": tool_calling
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
|
| 303 |
@app.get("/api/examples")
|
| 304 |
async def get_examples():
|
| 305 |
+
"""Return random example inputs from dataset, grouped by label"""
|
| 306 |
+
global dataset_by_label
|
| 307 |
+
|
| 308 |
+
# Load dataset if not already loaded
|
| 309 |
+
if dataset_by_label is None:
|
| 310 |
+
load_dataset()
|
| 311 |
+
|
| 312 |
+
examples = []
|
| 313 |
+
|
| 314 |
+
# If dataset is loaded, get random examples from each label
|
| 315 |
+
if dataset_by_label:
|
| 316 |
+
# Get one random example from each label
|
| 317 |
+
for label in LABEL_MAP.values():
|
| 318 |
+
if label in dataset_by_label and len(dataset_by_label[label]) > 0:
|
| 319 |
+
# Randomly select an example from this label
|
| 320 |
+
random_example = random.choice(dataset_by_label[label])
|
| 321 |
+
|
| 322 |
+
# Convert to API format
|
| 323 |
+
try:
|
| 324 |
+
api_format = convert_dataset_example_to_api_format(random_example)
|
| 325 |
+
|
| 326 |
+
# Create example entry
|
| 327 |
+
example_entry = {
|
| 328 |
+
"name": f"{label} Example",
|
| 329 |
+
"description": f"Example of {label.replace('_', ' ').title()}",
|
| 330 |
+
"data": api_format
|
| 331 |
}
|
| 332 |
+
examples.append(example_entry)
|
| 333 |
+
except Exception as e:
|
| 334 |
+
print(f"Error converting example for label {label}: {e}")
|
| 335 |
+
continue
|
| 336 |
+
else:
|
| 337 |
+
# Fallback to hardcoded examples if dataset not available
|
| 338 |
+
examples = [
|
| 339 |
+
{
|
| 340 |
+
"name": "Correct Example",
|
| 341 |
+
"description": "A properly formed tool call",
|
| 342 |
+
"data": {
|
| 343 |
+
"query": "What's the weather in New York?",
|
| 344 |
+
"enabled_tools": [
|
| 345 |
+
{
|
| 346 |
+
"name": "get_weather",
|
| 347 |
+
"description": "Get current weather for a location",
|
| 348 |
+
"parameters": {
|
| 349 |
+
"type": "object",
|
| 350 |
+
"properties": {
|
| 351 |
+
"location": {"type": "string"},
|
| 352 |
+
"units": {"type": "string", "enum": ["celsius", "fahrenheit"]}
|
| 353 |
+
},
|
| 354 |
+
"required": ["location"]
|
| 355 |
}
|
| 356 |
}
|
| 357 |
+
],
|
| 358 |
+
"tool_calling": {
|
| 359 |
+
"name": "get_weather",
|
| 360 |
+
"arguments": {
|
| 361 |
+
"location": "New York",
|
| 362 |
+
"units": "fahrenheit"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
}
|
| 364 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
}
|
| 366 |
}
|
| 367 |
+
]
|
| 368 |
+
|
| 369 |
+
# Shuffle examples to randomize order
|
| 370 |
+
random.shuffle(examples)
|
| 371 |
+
|
| 372 |
return {"examples": examples}
|
| 373 |
|
| 374 |
|