"""
FastAPI Backend for LLM Tool-Use Error Classifier.

Serves predictions from the fine-tuned Llama-3.2-3B model.
"""

import json
import os
import random
import time
import traceback
from collections import defaultdict
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Default model location: a HuggingFace Hub repo id or a local directory.
# Override it with the MODEL_PATH environment variable.
DEFAULT_MODEL_PATH = "daoqm123/llm-error-classifier"

# HuggingFace Space hosting the labelled example dataset used by /api/examples.
DATASET_REPO_ID = "daoqm123/llm-error-classifier-api"
DATASET_FILE = "xlam_function_calling_60k_processed_with_ground_truth.json"

# Global model state, populated by `lifespan` at startup.
model = None
tokenizer = None
device = None
# Dataset examples grouped by ground-truth label, populated by `load_dataset`.
dataset_by_label = None
dataset_path = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model, tokenizer and example dataset before serving requests.

    Populates the module-level ``model``, ``tokenizer`` and ``device`` globals,
    and (through ``load_dataset``) ``dataset_by_label``. Nothing is torn down
    on shutdown.
    """
    global model, tokenizer, device

    # Startup
    print("Loading model...")

    # MODEL_PATH may be a HuggingFace Hub repo id or a local directory.
    model_path = os.getenv("MODEL_PATH", DEFAULT_MODEL_PATH)
    print(f"Model path: {model_path}")

    # Half precision on GPU (bf16 where supported), full precision on CPU.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda")
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        print(f"Using GPU with dtype: {dtype}")
    else:
        device = torch.device("cpu")
        dtype = torch.float32
        print("Using CPU")

    print(f"Loading tokenizer from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print(f"Loading model from: {model_path}")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map="auto" if use_cuda else None,
    )
    if not use_cuda:
        # No device_map on CPU, so place the model explicitly.
        model = model.to(device)
    model.eval()
    print("Model loaded successfully!")

    # Example dataset; load_dataset degrades to hardcoded examples on failure.
    print("Loading dataset for examples...")
    load_dataset()

    yield  # Application runs here

    # Shutdown: nothing to clean up.


app = FastAPI(title="LLM Error Classifier API", version="1.0.0", lifespan=lifespan)

# Enable CORS for the frontend.
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette
# echoes the caller's Origin, so any site may send credentialed requests.
# Specify exact origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Classifier output index -> label name.
LABEL_MAP = {
    0: "Correct",
    1: "No_Tool_Available",
    2: "Incorrect_Function_Name",
    3: "Incorrect_Argument_Type",
    4: "Wrong_Syntax",
    5: "Wrong_Tool",
    6: "Incorrect_Argument_Value",
    7: "Incorrect_Argument_Name",
}

# Label names accepted when grouping dataset items (built once, O(1) lookup).
_VALID_LABELS = frozenset(LABEL_MAP.values())

# Color mapping for frontend
LABEL_COLORS = {
    "Correct": "#10B981",
    "No_Tool_Available": "#F59E0B",
    "Incorrect_Function_Name": "#EF4444",
    "Incorrect_Argument_Name": "#EC4899",
    "Incorrect_Argument_Value": "#8B5CF6",
    "Incorrect_Argument_Type": "#3B82F6",
    "Wrong_Tool": "#F97316",
    "Wrong_Syntax": "#DC2626",
}


class ClassificationRequest(BaseModel):
    """Request body for classification endpoint"""

    query: str
    enabled_tools: List[Dict[str, Any]]
    tool_calling: Dict[str, Any]


class ClassificationResponse(BaseModel):
    """Response from classification endpoint"""

    label: str
    confidence: float
    all_probabilities: Dict[str, float]
    processing_time_ms: int
    category_color: str


@app.get("/health")
async def health_check():
    """Report whether the model is loaded and which device it runs on."""
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "device": str(device) if device else "not initialized",
    }


@app.post("/api/classify", response_model=ClassificationResponse)
async def classify(request: ClassificationRequest):
    """Classify a tool call as correct or identify its error type.

    The request is serialised as JSON of ``{"query", "enabled_tools",
    "tool_calling"}`` (the same format as training, per the original code),
    truncated to 512 tokens, and run through the sequence classifier.

    Raises:
        HTTPException: 503 if the model/tokenizer are not loaded yet; 500 on
            any failure during tokenisation or inference.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    start_time = time.perf_counter()

    try:
        # Format input as JSON string (same format as training)
        input_text = json.dumps(
            {
                "query": request.query,
                "enabled_tools": request.enabled_tools,
                "tool_calling": request.tool_calling,
            }
        )

        # Padding a single sequence is a no-op. Not requesting it also avoids
        # the error HF tokenizers raise when asked to pad without a pad token.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.inference_mode():
            logits = model(**inputs).logits

        # Upcast before softmax: bf16/fp16 probabilities keep only ~3 significant digits.
        probs = torch.softmax(logits.float(), dim=-1)[0]
        pred_idx = int(torch.argmax(probs).item())
        confidence = probs[pred_idx].item()

        # Unknown indices (model/LABEL_MAP mismatch) get a placeholder name instead of a KeyError.
        all_probs = {
            LABEL_MAP.get(i, f"LABEL_{i}"): float(p) for i, p in enumerate(probs.tolist())
        }
        predicted_label = LABEL_MAP.get(pred_idx, f"LABEL_{pred_idx}")

        processing_time_ms = int((time.perf_counter() - start_time) * 1000)

        return ClassificationResponse(
            label=predicted_label,
            confidence=confidence,
            all_probabilities=all_probs,
            processing_time_ms=processing_time_ms,
            category_color=LABEL_COLORS.get(predicted_label, "#6B7280"),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}") from e


def load_dataset():
    """Download the example dataset and group its items by ground-truth label.

    Sets the module-level ``dataset_path`` and ``dataset_by_label`` globals.

    Returns:
        The ``label -> list of dataset items`` mapping, or ``None`` if the file
        could not be downloaded or parsed (callers then use hardcoded examples).
    """
    global dataset_by_label, dataset_path

    # This runs inside `lifespan`, so a download failure must not abort startup.
    # Degrade to the hardcoded examples instead.
    try:
        dataset_path = Path(
            hf_hub_download(
                repo_type="space",
                repo_id=DATASET_REPO_ID,
                filename=DATASET_FILE,
            )
        )
    except Exception as e:
        print(f"Warning: Could not download dataset file {DATASET_FILE}: {e}. Using hardcoded examples.")
        dataset_path = None
        dataset_by_label = None
        return None

    if not dataset_path.is_file():
        print(f"Warning: Dataset file not found at {dataset_path}. Using hardcoded examples.")
        dataset_by_label = None
        return None

    try:
        print(f"Loading dataset from: {dataset_path}")
        with open(dataset_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Build locally, then publish, so the global is never half-populated.
        grouped = defaultdict(list)
        for item in data:
            label = item.get("ground_truth", "Unknown")
            if label in _VALID_LABELS:
                grouped[label].append(item)
        dataset_by_label = grouped

        counts = {label: len(items) for label, items in dataset_by_label.items()}
        print(f"Loaded {len(data)} examples. Examples per label: {counts}")
        print(f"Global dataset_by_label is now set: {dataset_by_label is not None}")
        return dataset_by_label
    except Exception as e:
        print(f"Error loading dataset: {e}")
        traceback.print_exc()
        dataset_by_label = None
        return None


def get_fallback_examples() -> List[Dict[str, Any]]:
    """Return hardcoded examples, one per label type."""
    return [
        {
            "name": "Correct Example",
            "description": "A properly formed tool call",
            "expected_output": "Correct",
            "data": {
                "query": "What's the weather in New York?",
                "enabled_tools": [
                    {
                        "name": "get_weather",
                        "description": "Get current weather for a location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {"type": "string"},
                                "units": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                            },
                            "required": ["location"],
                        },
                    }
                ],
                "tool_calling": {
                    "name": "get_weather",
                    "arguments": {
                        "location": "New York",
                        "units": "fahrenheit",
                    },
                },
            },
        },
        {
            "name": "Incorrect Function Name Example",
            "description": "Tool call uses incorrect function name",
            "expected_output": "Incorrect_Function_Name",
            "data": {
                "query": "Calculate 25 * 4",
                "enabled_tools": [
                    {
                        "name": "calculator",
                        "description": "Perform calculations",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "expression": {"type": "string"},
                            },
                            "required": ["expression"],
                        },
                    }
                ],
                "tool_calling": {
                    "name": "calculate",  # Wrong name!
                    "arguments": {
                        "expression": "25 * 4",
                    },
                },
            },
        },
        {
            "name": "Incorrect Argument Type Example",
            "description": "Argument has wrong data type",
            "expected_output": "Incorrect_Argument_Type",
            "data": {
                "query": "Set a reminder for 3pm",
                "enabled_tools": [
                    {
                        "name": "set_reminder",
                        "description": "Create a reminder",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "time": {"type": "string"},
                                "message": {"type": "string"},
                            },
                            "required": ["time", "message"],
                        },
                    }
                ],
                "tool_calling": {
                    "name": "set_reminder",
                    "arguments": {
                        "time": 1500,  # Should be string!
                        "message": "Meeting",
                    },
                },
            },
        },
        {
            "name": "Incorrect Argument Name Example",
            "description": "Argument name doesn't match tool parameters",
            "expected_output": "Incorrect_Argument_Name",
            "data": {
                "query": "Send an email to john@example.com",
                "enabled_tools": [
                    {
                        "name": "send_email",
                        "description": "Send an email message",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "recipient": {"type": "string"},
                                "subject": {"type": "string"},
                                "body": {"type": "string"},
                            },
                            "required": ["recipient", "subject"],
                        },
                    }
                ],
                "tool_calling": {
                    "name": "send_email",
                    "arguments": {
                        "to": "john@example.com",  # Wrong name! Should be "recipient"
                        "subject": "Hello",
                        "body": "Test message",
                    },
                },
            },
        },
        {
            "name": "Incorrect Argument Value Example",
            "description": "Argument value doesn't match expected format",
            "expected_output": "Incorrect_Argument_Value",
            "data": {
                "query": "Get weather in Celsius",
                "enabled_tools": [
                    {
                        "name": "get_weather",
                        "description": "Get current weather for a location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {"type": "string"},
                                "units": {"type": "string", "enum": ["celsius", "fahrenheit", "kelvin"]},
                            },
                            "required": ["location"],
                        },
                    }
                ],
                "tool_calling": {
                    "name": "get_weather",
                    "arguments": {
                        "location": "London",
                        "units": "centigrade",  # Wrong value! Not in enum
                    },
                },
            },
        },
        {
            "name": "Wrong Tool Example",
            "description": "Wrong tool selected for the task",
            "expected_output": "Wrong_Tool",
            "data": {
                "query": "What's the weather in Paris?",
                "enabled_tools": [
                    {
                        "name": "get_weather",
                        "description": "Get current weather for a location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {"type": "string"},
                            },
                            "required": ["location"],
                        },
                    },
                    {
                        "name": "search_web",
                        "description": "Search the web for information",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "query": {"type": "string"},
                            },
                            "required": ["query"],
                        },
                    },
                ],
                "tool_calling": {
                    "name": "search_web",  # Wrong tool! Should use get_weather
                    "arguments": {
                        "query": "weather in Paris",
                    },
                },
            },
        },
        {
            "name": "Wrong Syntax Example",
            "description": "Tool call syntax is malformed",
            "expected_output": "Wrong_Syntax",
            "data": {
                "query": "Calculate 10 + 5",
                "enabled_tools": [
                    {
                        "name": "calculator",
                        "description": "Perform mathematical calculations",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "expression": {"type": "string"},
                            },
                            "required": ["expression"],
                        },
                    }
                ],
                "tool_calling": {
                    "name": "calculator",
                    "arguments": {
                        # NOTE(review): this reads like an argument-type error rather
                        # than a syntax error; confirm it is the intended Wrong_Syntax sample.
                        "expression": ["10", "+", "5"],  # Wrong type! Should be string
                    },
                },
            },
        },
        {
            "name": "No Tool Available Example",
            "description": "No matching tool exists for the request",
            "expected_output": "No_Tool_Available",
            "data": {
                "query": "Translate 'Hello' to Spanish",
                "enabled_tools": [
                    {
                        "name": "get_weather",
                        "description": "Get current weather for a location",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {"type": "string"},
                            },
                            "required": ["location"],
                        },
                    },
                    {
                        "name": "calculator",
                        "description": "Perform mathematical calculations",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "expression": {"type": "string"},
                            },
                            "required": ["expression"],
                        },
                    },
                ],
                "tool_calling": {
                    "name": "translate",  # Tool doesn't exist in enabled_tools!
                    "arguments": {
                        "text": "Hello",
                        "target_language": "Spanish",
                    },
                },
            },
        },
    ]


# Python type names used in dataset tool specs -> JSON Schema type names.
# Anything not listed maps to "string".
_PY_TO_JSON_TYPE = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "list": "array",
    "dict": "object",
}


def _maybe_parse_json(value: Any) -> Any:
    """Decode *value* if it is a non-empty string; otherwise return it unchanged.

    NOTE(review): the raw xLAM dataset may store ``tools``/``answers`` as
    JSON-encoded strings; this accepts both that and already-parsed lists.
    """
    if isinstance(value, str) and value.strip():
        return json.loads(value)
    return value


def convert_dataset_example_to_api_format(item: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a dataset item into the ``ClassificationRequest`` payload shape.

    Each tool's flat ``{param: {"type", "description", ...}}`` mapping becomes
    a JSON Schema ``parameters`` object. A parameter without a ``default`` is
    treated as required. The first entry of ``answers`` (if any) becomes
    ``tool_calling``; otherwise ``tool_calling`` is ``None``.
    """
    enabled_tools = []
    for tool in _maybe_parse_json(item.get("tools")) or []:
        properties = {}
        required = []

        tool_params = tool.get("parameters", {})
        if isinstance(tool_params, dict):
            for param_name, param_info in tool_params.items():
                if not isinstance(param_info, dict):
                    continue
                prop = {"type": _PY_TO_JSON_TYPE.get(param_info.get("type", "string"), "string")}
                if "description" in param_info:
                    prop["description"] = param_info["description"]
                if "enum" in param_info:
                    prop["enum"] = param_info["enum"]
                # Heuristic: no default value => the parameter is presumably required.
                if "default" not in param_info:
                    required.append(param_name)
                properties[param_name] = prop

        tool_schema = {
            "name": tool.get("name", ""),
            "description": tool.get("description", ""),
            "parameters": {
                "type": "object",
                "properties": properties,
            },
        }
        if required:
            tool_schema["parameters"]["required"] = required
        enabled_tools.append(tool_schema)

    # Use the first recorded answer as the tool call under review.
    tool_calling = None
    answers = _maybe_parse_json(item.get("answers"))
    if answers:
        answer = answers[0]
        tool_calling = {
            "name": answer.get("name", ""),
            "arguments": answer.get("arguments", {}),
        }

    return {
        "query": item.get("query", ""),
        "enabled_tools": enabled_tools,
        "tool_calling": tool_calling,
    }


@app.get("/api/examples")
async def get_examples():
    """Return one random dataset example per label, in shuffled order.

    If the dataset is not loaded yet, ``load_dataset`` is retried. The
    hardcoded examples are used when the dataset is unavailable or when no
    dataset example could be converted.
    """
    if dataset_by_label is None:
        print("Dataset not loaded, attempting to load...")
        if load_dataset() is None:
            print("Failed to load dataset, using fallback examples")

    examples = []

    if dataset_by_label:
        print(f"Using dataset with {len(dataset_by_label)} label categories")

        for label in LABEL_MAP.values():
            # .get avoids inserting empty lists into the defaultdict.
            candidates = dataset_by_label.get(label)
            if not candidates:
                continue
            random_example = random.choice(candidates)
            try:
                api_format = convert_dataset_example_to_api_format(random_example)
            except Exception as e:
                print(f"Error converting example for label {label}: {e}")
                traceback.print_exc()
                continue
            examples.append(
                {
                    "name": f"{label} Example",
                    "description": f"Example of {label.replace('_', ' ').title()}",
                    "data": api_format,
                    "expected_output": label,  # ground-truth label
                }
            )

        print(f"Generated {len(examples)} random examples from dataset")

    if not examples:
        # Dataset unavailable, or nothing usable came out of it.
        print("Using hardcoded fallback examples for all labels")
        examples = get_fallback_examples()

    random.shuffle(examples)

    print(f"Returning {len(examples)} examples (shuffled)")
    return {"examples": examples}


if __name__ == "__main__":
    import uvicorn

    # HuggingFace Spaces uses port 7860, but allow override via environment variable
    port = int(os.getenv("PORT", 7860))
    # Use 0.0.0.0 to allow external connections
    uvicorn.run(app, host="0.0.0.0", port=port)