daoqm123 committed on
Commit
56db9b3
·
1 Parent(s): 877b44a

Update backend

Browse files
Files changed (1) hide show
  1. main.py +172 -77
main.py CHANGED
@@ -12,12 +12,16 @@ import json
12
  import os
13
  import time
14
  import torch
 
15
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
16
 
17
  # Global model and tokenizer
18
  model = None
19
  tokenizer = None
20
  device = None
 
 
21
 
22
  os.environ["CUDA_VISIBLE_DEVICES"] = "7"
23
  @asynccontextmanager
@@ -59,6 +63,10 @@ async def lifespan(app: FastAPI):
59
  model.eval()
60
  print("Model loaded successfully!")
61
 
 
 
 
 
62
  yield # Application runs here
63
 
64
  # Shutdown (if needed)
@@ -188,92 +196,179 @@ async def classify(request: ClassificationRequest):
188
  raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
189
 
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
@app.get("/api/examples")
async def get_examples():
    """Return the built-in sample payloads for trying out the classifier.

    Returns:
        A dict with a single key ``"examples"`` holding three hard-coded
        entries: one well-formed tool call, one whose function name does not
        match the enabled tool, and one whose argument has the wrong type.
    """
    # Well-formed call: tool name and argument types agree with the schema.
    correct_example = {
        "name": "Correct Example",
        "description": "A properly formed tool call",
        "data": {
            "query": "What's the weather in New York?",
            "enabled_tools": [
                {
                    "name": "get_weather",
                    "description": "Get current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string"},
                            "units": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                        },
                        "required": ["location"],
                    },
                }
            ],
            "tool_calling": {
                "name": "get_weather",
                "arguments": {
                    "location": "New York",
                    "units": "fahrenheit",
                },
            },
        },
    }

    # The call names "calculate" although only "calculator" is enabled.
    wrong_name_example = {
        "name": "Wrong Function Name",
        "description": "Tool call uses incorrect function name",
        "data": {
            "query": "Calculate 25 * 4",
            "enabled_tools": [
                {
                    "name": "calculator",
                    "description": "Perform calculations",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "expression": {"type": "string"},
                        },
                    },
                }
            ],
            "tool_calling": {
                "name": "calculate",
                "arguments": {
                    "expression": "25 * 4",
                },
            },
        },
    }

    # "time" is sent as an integer although the schema declares a string.
    wrong_type_example = {
        "name": "Incorrect Argument Type",
        "description": "Argument has wrong data type",
        "data": {
            "query": "Set a reminder for 3pm",
            "enabled_tools": [
                {
                    "name": "set_reminder",
                    "description": "Create a reminder",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "time": {"type": "string"},
                            "message": {"type": "string"},
                        },
                    },
                }
            ],
            "tool_calling": {
                "name": "set_reminder",
                "arguments": {
                    "time": 1500,
                    "message": "Meeting",
                },
            },
        },
    }

    examples = [correct_example, wrong_name_example, wrong_type_example]
    return {"examples": examples}
278
 
279
 
 
12
  import os
13
  import time
14
  import torch
15
+ import random
16
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
17
+ from collections import defaultdict
18
 
19
  # Global model and tokenizer
20
  model = None
21
  tokenizer = None
22
  device = None
23
+ dataset_by_label = None
24
+ dataset_path = None
25
 
26
  os.environ["CUDA_VISIBLE_DEVICES"] = "7"
27
  @asynccontextmanager
 
63
  model.eval()
64
  print("Model loaded successfully!")
65
 
66
+ # Load dataset for examples
67
+ print("Loading dataset for examples...")
68
+ load_dataset()
69
+
70
  yield # Application runs here
71
 
72
  # Shutdown (if needed)
 
196
  raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
197
 
198
 
199
def load_dataset():
    """Load the example dataset and group its items by ground-truth label.

    The first existing path among the candidate locations (a path relative
    to this file, a fixed absolute path, then the ``DATASET_PATH`` environment
    variable) is read as UTF-8 JSON. Items whose ``ground_truth`` value is one
    of the values of ``LABEL_MAP`` are grouped under that label.

    Side effects:
        Sets the module globals ``dataset_path`` (chosen path or ``None``) and,
        only when loading fully succeeds, ``dataset_by_label``.

    Returns:
        The ``defaultdict(list)`` mapping label -> list of items, or ``None``
        when no dataset file exists or it cannot be read/parsed.
    """
    global dataset_by_label, dataset_path

    # Candidate locations, tried in order; the first one that exists wins.
    possible_paths = [
        os.path.join(os.path.dirname(__file__), "../../dataset/xlam-function-calling-60k/xlam_function_calling_60k_processed_with_ground_truth.json"),
        "/work/cssema416/202610/12/dataset/xlam-function-calling-60k/xlam_function_calling_60k_processed_with_ground_truth.json",
        os.getenv("DATASET_PATH", ""),
    ]

    dataset_path = next(
        (path for path in possible_paths if path and os.path.exists(path)),
        None,
    )
    if not dataset_path:
        print("Warning: Dataset file not found. Using hardcoded examples.")
        return None

    try:
        print(f"Loading dataset from: {dataset_path}")
        # Explicit encoding: JSON is UTF-8, the platform default may not be.
        with open(dataset_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Evaluate the label set once instead of once per item.
        known_labels = tuple(LABEL_MAP.values())

        # Group into a local first so a failure half-way through never
        # publishes a partially filled mapping through the global.
        grouped = defaultdict(list)
        for item in data:
            if not isinstance(item, dict):
                # One malformed record should not discard the whole dataset.
                continue
            label = item.get('ground_truth', 'Unknown')
            if label in known_labels:
                grouped[label].append(item)
        total = len(data)
    except (OSError, ValueError, TypeError) as e:
        # OSError: unreadable file; ValueError: invalid JSON or encoding;
        # TypeError: top-level JSON value is not a collection of items.
        print(f"Error loading dataset: {e}")
        return None

    dataset_by_label = grouped
    counts = {label: len(items) for label, items in dataset_by_label.items()}
    print(f"Loaded {total} examples. Examples per label: {counts}")
    return dataset_by_label
237
+
238
+
239
# Python-style type names (as written in the dataset) -> JSON Schema types.
_PYTHON_TO_JSON_TYPE = {
    'str': 'string',
    'int': 'integer',
    'float': 'number',
    'bool': 'boolean',
    'list': 'array',
    'dict': 'object',
}

# Names that are already valid JSON Schema primitive types; passed through.
_JSON_SCHEMA_TYPES = frozenset(
    {'string', 'integer', 'number', 'boolean', 'array', 'object', 'null'}
)


def _to_json_schema_type(raw_type):
    """Map a dataset type annotation to a JSON Schema type name.

    Handles plain Python names (``"int"``), decorated forms such as
    ``"int, optional"`` or ``"List[str]"``, and names that are already JSON
    Schema types. Anything unrecognised (including non-strings) becomes
    ``"string"``.
    """
    if not isinstance(raw_type, str):
        return 'string'
    # "int, optional" -> "int"; "List[str]" -> "list"; "Dict[str, int]" -> "dict"
    base = raw_type.split(',', 1)[0].split('[', 1)[0].strip().lower()
    if base in _JSON_SCHEMA_TYPES:
        return base
    return _PYTHON_TO_JSON_TYPE.get(base, 'string')


def convert_dataset_example_to_api_format(item: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one dataset record into the request shape used by the API.

    Args:
        item: Dataset record. Reads ``query``, ``tools`` (list of dicts with
            ``name``, ``description`` and a ``parameters`` mapping of
            parameter name -> info dict) and ``answers`` (list of dicts with
            ``name`` and ``arguments``).

    Returns:
        A dict with ``query``, ``enabled_tools`` (each tool's parameters
        rewritten as a JSON-Schema-style object) and ``tool_calling`` (built
        from the first answer, or ``None`` when there is no usable answer).
    """
    enabled_tools = []
    # ``or []`` also covers an explicit ``"tools": null`` in the record.
    for tool in item.get('tools') or []:
        if not isinstance(tool, dict):
            continue

        properties = {}
        required = []

        tool_params = tool.get('parameters', {})
        if isinstance(tool_params, dict):
            for param_name, param_info in tool_params.items():
                if not isinstance(param_info, dict):
                    continue

                param_type = param_info.get('type', 'string')
                prop = {"type": _to_json_schema_type(param_type)}
                if 'description' in param_info:
                    prop['description'] = param_info['description']
                if 'enum' in param_info:
                    prop['enum'] = param_info['enum']

                # Heuristic: a parameter with no default is treated as
                # required, unless its type text explicitly says "optional".
                declared_optional = (
                    isinstance(param_type, str) and 'optional' in param_type.lower()
                )
                if 'default' not in param_info and not declared_optional:
                    required.append(param_name)

                properties[param_name] = prop

        tool_schema = {
            "name": tool.get('name', ''),
            "description": tool.get('description', ''),
            "parameters": {
                "type": "object",
                "properties": properties
            }
        }
        if required:
            tool_schema["parameters"]["required"] = required

        enabled_tools.append(tool_schema)

    # Only the first answer is exposed as the tool call.
    tool_calling = None
    answers = item.get('answers')
    if isinstance(answers, (list, tuple)) and answers and isinstance(answers[0], dict):
        answer = answers[0]
        tool_calling = {
            "name": answer.get('name', ''),
            "arguments": answer.get('arguments', {})
        }

    return {
        "query": item.get('query', ''),
        "enabled_tools": enabled_tools,
        "tool_calling": tool_calling
    }
301
+
302
+
303
@app.get("/api/examples")
async def get_examples():
    """Return example inputs, preferring random samples from the dataset.

    For each label in ``LABEL_MAP`` that has items in ``dataset_by_label``,
    one item is picked at random and converted with
    ``convert_dataset_example_to_api_format``. If the dataset is unavailable,
    or it yields no usable example at all, a single hard-coded example is
    returned instead so the endpoint never responds with an empty list.

    Returns:
        A dict with a single key ``"examples"`` holding the (shuffled) list
        of example entries, each with ``name``, ``description`` and ``data``.
    """
    # Lazy load in case the startup load did not populate the dataset.
    if dataset_by_label is None:
        load_dataset()

    examples = []

    if dataset_by_label:
        for label in LABEL_MAP.values():
            candidates = dataset_by_label.get(label)
            if not candidates:
                continue

            try:
                api_format = convert_dataset_example_to_api_format(
                    random.choice(candidates)
                )
            except Exception as e:
                # Best effort: one bad record must not break the endpoint.
                print(f"Error converting example for label {label}: {e}")
                continue

            examples.append({
                "name": f"{label} Example",
                "description": f"Example of {label.replace('_', ' ').title()}",
                "data": api_format
            })

    if not examples:
        # Fallback to hardcoded examples if the dataset is not available or
        # none of its records could be converted.
        examples = [
            {
                "name": "Correct Example",
                "description": "A properly formed tool call",
                "data": {
                    "query": "What's the weather in New York?",
                    "enabled_tools": [
                        {
                            "name": "get_weather",
                            "description": "Get current weather for a location",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "location": {"type": "string"},
                                    "units": {"type": "string", "enum": ["celsius", "fahrenheit"]}
                                },
                                "required": ["location"]
                            }
                        }
                    ],
                    "tool_calling": {
                        "name": "get_weather",
                        "arguments": {
                            "location": "New York",
                            "units": "fahrenheit"
                        }
                    }
                }
            }
        ]

    # Shuffle examples to randomize order
    random.shuffle(examples)

    return {"examples": examples}
373
 
374