Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI Backend for LLM Tool-Use Error Classifier | |
| Serves predictions from the fine-tuned Llama-3.2-3B model | |
| """ | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Dict, Any, List | |
| import json | |
| import os | |
| import time | |
| import torch | |
| import random | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
# Shared runtime state. Every name starts out as None and is filled in later:
# lifespan() assigns model / tokenizer / device, and load_dataset() assigns
# dataset_by_label / dataset_path.
model = tokenizer = device = None
dataset_by_label = dataset_path = None
# os.environ["CUDA_VISIBLE_DEVICES"] = "7"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the classifier at startup, then hand control to the application.

    Startup populates the module-level ``model``, ``tokenizer`` and ``device``
    globals and then tries to load the example dataset. Nothing is released
    on shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None, once startup has finished; the application serves requests
        while this generator is suspended.
    """
    global model, tokenizer, device

    # --- Startup ---------------------------------------------------------
    print("Loading model...")
    # A local directory or a HuggingFace Hub repo id. MODEL_PATH overrides
    # the published checkpoint, which remains the default.
    model_path = os.getenv("MODEL_PATH", "daoqm123/llm-error-classifier")
    print(f"Model path: {model_path}")

    # Pick the device and the dtype the weights are loaded in.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda")
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        print(f"Using GPU with dtype: {dtype}")
    else:
        device = torch.device("cpu")
        dtype = torch.float32
        print("Using CPU")

    print(f"Loading tokenizer from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print(f"Loading model from: {model_path}")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        torch_dtype=dtype,
        # On GPU placement is delegated to device_map; on CPU the model is
        # moved explicitly below.
        device_map="auto" if use_cuda else None,
    )
    if not use_cuda:
        model = model.to(device)
    model.eval()
    print("Model loaded successfully!")

    # The dataset only feeds the examples listing, which has a hardcoded
    # fallback, so a failure here must not prevent the API from starting.
    print("Loading dataset for examples...")
    try:
        load_dataset()
    except Exception as e:
        print(f"Error loading dataset: {e}")

    yield  # Application runs here

    # --- Shutdown --------------------------------------------------------
    # No cleanup is performed.
app = FastAPI(title="LLM Error Classifier API", version="1.0.0", lifespan=lifespan)

# CORS for the frontend. ALLOWED_ORIGINS is a comma-separated list of exact
# origins; when it is unset or empty, any origin is allowed.
_allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    # Credentialed cross-origin requests must not be combined with a wildcard
    # origin, so credentials are enabled only when origins are explicit.
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Class index -> label name. classify() looks up the argmax position of the
# model's output in this table.
LABEL_MAP = dict(enumerate((
    "Correct",
    "No_Tool_Available",
    "Incorrect_Function_Name",
    "Incorrect_Argument_Type",
    "Wrong_Syntax",
    "Wrong_Tool",
    "Incorrect_Argument_Value",
    "Incorrect_Argument_Name",
)))

# Label name -> hex color handed to the frontend alongside a prediction.
LABEL_COLORS = dict(
    Correct="#10B981",
    No_Tool_Available="#F59E0B",
    Incorrect_Function_Name="#EF4444",
    Incorrect_Argument_Name="#EC4899",
    Incorrect_Argument_Value="#8B5CF6",
    Incorrect_Argument_Type="#3B82F6",
    Wrong_Tool="#F97316",
    Wrong_Syntax="#DC2626",
)
class ClassificationRequest(BaseModel):
    """Payload accepted by the classification endpoint.

    The three fields are serialized together into the JSON text that
    classify() feeds to the tokenizer.
    """

    # The user request the tool call is meant to satisfy.
    query: str
    # Definitions of the tools that were available for the call.
    enabled_tools: List[Dict[str, Any]]
    # The tool call being judged.
    tool_calling: Dict[str, Any]
class ClassificationResponse(BaseModel):
    """Result returned by the classification endpoint."""

    # Name of the predicted class (a value of LABEL_MAP).
    label: str
    # Probability assigned to the predicted class.
    confidence: float
    # Probability for every class, keyed by label name.
    all_probabilities: Dict[str, float]
    # Time spent handling the request, in whole milliseconds.
    processing_time_ms: int
    # Hex color associated with the predicted label.
    category_color: str
# NOTE(review): no route decorator is visible in this view of the file;
# confirm this handler is registered on `app`.
async def health_check():
    """Report that the service is up, plus model and device status."""
    device_name = "not initialized" if not device else str(device)
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "device": device_name,
    }
# NOTE(review): no route decorator is visible in this view of the file;
# confirm this handler is registered on `app`.
async def classify(request: ClassificationRequest):
    """Classify a tool call as correct or identify its error type.

    The request is serialized to a JSON string, tokenized (truncated to 512
    tokens) and run through the sequence-classification model. Inference runs
    synchronously inside this coroutine.

    Args:
        request: The query, the enabled tool definitions and the tool call
            to judge.

    Returns:
        A ClassificationResponse with the predicted label, its probability,
        the probability of every label, the elapsed time in milliseconds and
        the label's display color.

    Raises:
        HTTPException: 503 when the model or tokenizer has not been loaded;
            500 when anything fails during classification.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()
    try:
        # Format input as a JSON string (same format as training).
        input_text = json.dumps({
            "query": request.query,
            "enabled_tools": request.enabled_tools,
            "tool_calling": request.tool_calling,
        })

        encoded = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in encoded.items()}

        with torch.no_grad():
            logits = model(**inputs).logits

        # The model may be loaded in bfloat16/float16; do the softmax in
        # float32 so the reported probabilities are not coarsely rounded.
        probs = torch.softmax(logits.float(), dim=-1)[0]
        pred_idx = int(torch.argmax(probs))
        # One transfer to Python floats instead of one .item() per class.
        prob_values = probs.tolist()

        all_probs = {LABEL_MAP[i]: p for i, p in enumerate(prob_values)}
        predicted_label = LABEL_MAP[pred_idx]
        processing_time_ms = int((time.perf_counter() - started) * 1000)

        return ClassificationResponse(
            label=predicted_label,
            confidence=prob_values[pred_idx],
            all_probabilities=all_probs,
            processing_time_ms=processing_time_ms,
            category_color=LABEL_COLORS.get(predicted_label, "#6B7280"),
        )
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(
            status_code=500, detail=f"Classification error: {str(e)}"
        ) from e
def load_dataset():
    """Download the example dataset and group its items by label.

    Sets the module-level ``dataset_path`` to the downloaded file and
    ``dataset_by_label`` to a mapping of label name -> list of items. Items
    whose ``ground_truth`` is not one of the LABEL_MAP values are dropped.

    Returns:
        The label -> items mapping, or None when the download, the file read
        or the JSON parsing fails (``dataset_by_label`` is then None too, so
        callers fall back to the hardcoded examples).
    """
    global dataset_by_label, dataset_path
    import traceback

    DATASET_FILE = "xlam_function_calling_60k_processed_with_ground_truth.json"

    try:
        # The download sits inside the try block: a network or hub error must
        # degrade to the fallback examples rather than raise to the caller.
        dataset_path = Path(
            hf_hub_download(
                repo_type="space",
                repo_id="daoqm123/llm-error-classifier-api",
                filename=DATASET_FILE,
            )
        )
        print(f"Loading dataset from: {dataset_path}")
        with open(dataset_path, 'r', encoding="utf-8") as f:
            data = json.load(f)

        # Built once rather than on every iteration of the loop below.
        known_labels = tuple(LABEL_MAP.values())
        grouped = defaultdict(list)
        for item in data:
            label = item.get('ground_truth', 'Unknown')
            if label in known_labels:
                grouped[label].append(item)
    except Exception as e:
        print("Warning: Dataset file not found. Using hardcoded examples.")
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
        dataset_by_label = None
        return None

    # Publish the grouping only once it has been built completely.
    dataset_by_label = grouped
    print(f"Loaded {len(data)} examples. Examples per label: {dict((k, len(v)) for k, v in dataset_by_label.items())}")
    print(f"Global dataset_by_label is now set: {dataset_by_label is not None}")
    return dataset_by_label
def get_fallback_examples() -> List[Dict[str, Any]]:
    """Build the hardcoded example list, one entry per label type.

    Every call returns freshly built, independent dictionaries.
    """

    def string_param():
        # Schema for a plain string parameter; a new dict on every call so
        # that no two examples share a nested object.
        return {"type": "string"}

    def tool(name, description, properties, required):
        # One enabled-tool definition with a JSON-Schema style parameter block.
        return {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }

    def example(name, description, expected_output, query, enabled_tools,
                call_name, arguments):
        # One example entry in the shape the examples listing returns.
        return {
            "name": name,
            "description": description,
            "expected_output": expected_output,
            "data": {
                "query": query,
                "enabled_tools": enabled_tools,
                "tool_calling": {"name": call_name, "arguments": arguments},
            },
        }

    def weather_tool(properties):
        return tool(
            "get_weather",
            "Get current weather for a location",
            properties,
            ["location"],
        )

    def calculator_tool(description):
        return tool(
            "calculator",
            description,
            {"expression": string_param()},
            ["expression"],
        )

    return [
        example(
            "Correct Example",
            "A properly formed tool call",
            "Correct",
            "What's the weather in New York?",
            [
                weather_tool({
                    "location": string_param(),
                    "units": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                }),
            ],
            "get_weather",
            {"location": "New York", "units": "fahrenheit"},
        ),
        example(
            "Incorrect Function Name Example",
            "Tool call uses incorrect function name",
            "Incorrect_Function_Name",
            "Calculate 25 * 4",
            [calculator_tool("Perform calculations")],
            "calculate",  # the enabled tool is named "calculator"
            {"expression": "25 * 4"},
        ),
        example(
            "Incorrect Argument Type Example",
            "Argument has wrong data type",
            "Incorrect_Argument_Type",
            "Set a reminder for 3pm",
            [
                tool(
                    "set_reminder",
                    "Create a reminder",
                    {"time": string_param(), "message": string_param()},
                    ["time", "message"],
                ),
            ],
            "set_reminder",
            {"time": 1500, "message": "Meeting"},  # "time" is declared as a string
        ),
        example(
            "Incorrect Argument Name Example",
            "Argument name doesn't match tool parameters",
            "Incorrect_Argument_Name",
            "Send an email to john@example.com",
            [
                tool(
                    "send_email",
                    "Send an email message",
                    {
                        "recipient": string_param(),
                        "subject": string_param(),
                        "body": string_param(),
                    },
                    ["recipient", "subject"],
                ),
            ],
            "send_email",
            {
                "to": "john@example.com",  # the schema calls this "recipient"
                "subject": "Hello",
                "body": "Test message",
            },
        ),
        example(
            "Incorrect Argument Value Example",
            "Argument value doesn't match expected format",
            "Incorrect_Argument_Value",
            "Get weather in Celsius",
            [
                weather_tool({
                    "location": string_param(),
                    "units": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit", "kelvin"],
                    },
                }),
            ],
            "get_weather",
            {"location": "London", "units": "centigrade"},  # not in the enum
        ),
        example(
            "Wrong Tool Example",
            "Wrong tool selected for the task",
            "Wrong_Tool",
            "What's the weather in Paris?",
            [
                weather_tool({"location": string_param()}),
                tool(
                    "search_web",
                    "Search the web for information",
                    {"query": string_param()},
                    ["query"],
                ),
            ],
            "search_web",  # get_weather is the matching tool
            {"query": "weather in Paris"},
        ),
        example(
            "Wrong Syntax Example",
            "Tool call syntax is malformed",
            "Wrong_Syntax",
            "Calculate 10 + 5",
            [calculator_tool("Perform mathematical calculations")],
            "calculator",
            {"expression": ["10", "+", "5"]},  # a list where a string is declared
        ),
        example(
            "No Tool Available Example",
            "No matching tool exists for the request",
            "No_Tool_Available",
            "Translate 'Hello' to Spanish",
            [
                weather_tool({"location": string_param()}),
                calculator_tool("Perform mathematical calculations"),
            ],
            "translate",  # not among the enabled tools
            {"text": "Hello", "target_language": "Spanish"},
        ),
    ]
# Python-style type names used by the dataset -> JSON Schema type names.
_JSON_TYPE_BY_PYTHON_NAME = {
    'str': 'string',
    'int': 'integer',
    'float': 'number',
    'bool': 'boolean',
    'list': 'array',
    'dict': 'object',
}
# Names that are already valid JSON Schema types and must pass through as-is.
_JSON_SCHEMA_TYPE_NAMES = frozenset(_JSON_TYPE_BY_PYTHON_NAME.values())


def _json_schema_type(raw_type: Any) -> str:
    """Map a dataset type name to a JSON Schema type, defaulting to 'string'."""
    if not isinstance(raw_type, str):
        return 'string'
    if raw_type in _JSON_SCHEMA_TYPE_NAMES:
        # Already in JSON Schema form (e.g. 'integer'); do not degrade it.
        return raw_type
    return _JSON_TYPE_BY_PYTHON_NAME.get(raw_type, 'string')


def convert_dataset_example_to_api_format(item: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a dataset example into the classification API's input shape.

    Each entry of ``item['tools']`` becomes a tool definition whose parameters
    are expressed as a JSON-Schema style object. Parameter entries that are
    not dictionaries are skipped, and a parameter with no ``default`` key is
    listed as required. The tool call is taken from the first entry of
    ``item['answers']``.

    Args:
        item: A dataset record; the keys read are ``query``, ``tools`` and
            ``answers``. Missing or None ``tools``/``answers`` are treated as
            empty.

    Returns:
        A dict with ``query``, ``enabled_tools`` and ``tool_calling``.
        ``tool_calling`` is None when the record has no answers.
    """
    enabled_tools = []
    for tool in item.get('tools') or []:
        properties = {}
        required = []
        tool_params = tool.get('parameters', {})
        if isinstance(tool_params, dict):
            for param_name, param_info in tool_params.items():
                if not isinstance(param_info, dict):
                    continue
                prop = {"type": _json_schema_type(param_info.get('type', 'string'))}
                if 'description' in param_info:
                    prop['description'] = param_info['description']
                if 'enum' in param_info:
                    prop['enum'] = param_info['enum']
                # Heuristic: a parameter without a default is treated as required.
                if 'default' not in param_info:
                    required.append(param_name)
                properties[param_name] = prop

        tool_schema = {
            "name": tool.get('name', ''),
            "description": tool.get('description', ''),
            "parameters": {
                "type": "object",
                "properties": properties,
            },
        }
        # "required" is emitted only when at least one parameter is required.
        if required:
            tool_schema["parameters"]["required"] = required
        enabled_tools.append(tool_schema)

    # Only the first answer is used as the tool call under evaluation.
    tool_calling = None
    answers = item.get('answers')
    if answers:
        answer = answers[0]
        tool_calling = {
            "name": answer.get('name', ''),
            "arguments": answer.get('arguments', {}),
        }

    return {
        "query": item.get('query', ''),
        "enabled_tools": enabled_tools,
        "tool_calling": tool_calling,
    }
# NOTE(review): no route decorator is visible in this view of the file;
# confirm this handler is registered on `app`.
async def get_examples():
    """Return example inputs, one per label, in random order.

    When the dataset is available, one random item per label is converted to
    the API input format; items that fail to convert are skipped. The
    hardcoded examples are used when the dataset cannot be loaded or when no
    item could be converted.

    Returns:
        ``{"examples": [...]}`` where each entry has ``name``,
        ``description``, ``data`` and ``expected_output``.
    """
    import traceback

    # Lazily retry the load if startup did not manage to populate the dataset.
    if dataset_by_label is None:
        print("Dataset not loaded, attempting to load...")
        try:
            loaded = load_dataset()
        except Exception as e:
            # A failed load must degrade to the fallback, not fail the request.
            print(f"Error loading dataset: {e}")
            loaded = None
        if loaded is None:
            print("Failed to load dataset, using fallback examples")

    examples = []
    if dataset_by_label:
        print(f"Using dataset with {len(dataset_by_label)} label categories")
        for label in LABEL_MAP.values():
            # .get() avoids inserting empty lists into the defaultdict.
            candidates = dataset_by_label.get(label)
            if not candidates:
                continue
            random_example = random.choice(candidates)
            try:
                api_format = convert_dataset_example_to_api_format(random_example)
            except Exception as e:
                print(f"Error converting example for label {label}: {e}")
                traceback.print_exc()
                continue
            examples.append({
                "name": f"{label} Example",
                "description": f"Example of {label.replace('_', ' ').title()}",
                "data": api_format,
                "expected_output": label,  # ground-truth label
            })
        print(f"Generated {len(examples)} random examples from dataset")

    if not examples:
        # Covers both "dataset unavailable" and "nothing could be converted",
        # so the caller never receives an empty list.
        print("Using hardcoded fallback examples for all labels")
        examples = get_fallback_examples()

    random.shuffle(examples)
    print(f"Returning {len(examples)} examples (shuffled)")
    return {"examples": examples}
if __name__ == "__main__":
    import uvicorn

    # Port 7860 is what HuggingFace Spaces uses; the PORT environment
    # variable overrides it.
    listen_port = int(os.environ.get("PORT", 7860))
    # Bind to all interfaces so external connections are accepted.
    uvicorn.run(app, host="0.0.0.0", port=listen_port)