daoqm123 committed on
Commit 877b44a · 1 Parent(s): e7916fb

Deploy FastAPI backend

Files changed (4)
  1. Dockerfile +29 -0
  2. README.md +62 -5
  3. main.py +285 -0
  4. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,29 @@
+ FROM python:3.11-slim
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY main.py .
+
+ # Expose port (HuggingFace Spaces uses port 7860)
+ EXPOSE 7860
+
+ # Set environment variables
+ ENV PORT=7860
+ ENV PYTHONUNBUFFERED=1
+
+ # Run the application
+ CMD ["python", "main.py"]
+
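The `ENV PORT=7860` / `EXPOSE 7860` pair matters because `main.py` reads `PORT` back at startup via `os.getenv`. A minimal sketch of that handoff; the `resolve_port` helper is illustrative, not part of the repo:

```python
import os

def resolve_port(default: int = 7860) -> int:
    # Mirror main.py's startup logic: the Dockerfile sets ENV PORT=7860,
    # but any runtime override of PORT wins.
    return int(os.getenv("PORT", default))

# Without PORT set, the Space default applies
os.environ.pop("PORT", None)
print(resolve_port())  # 7860

# With an override, e.g. `docker run -e PORT=8080 ...`
os.environ["PORT"] = "8080"
print(resolve_port())  # 8080
```

This is why the container works unchanged on Spaces (which expects 7860) and locally with a custom port.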
README.md CHANGED
@@ -1,10 +1,67 @@
  ---
- title: Llm Error Classifier Api
- emoji: 👀
- colorFrom: gray
- colorTo: green
  sdk: docker
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: LLM Error Classifier API
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: purple
  sdk: docker
+ sdk_version: 20.10.24
+ app_file: main.py
  pinned: false
+ license: mit
  ---

+ # LLM Error Classifier API
+
+ FastAPI backend serving the fine-tuned Llama-3.2-3B model for tool-use error classification.
+
+ ## API Endpoints
+
+ - `POST /api/classify` - Classify a tool call
+ - `GET /api/examples` - Get example inputs
+ - `GET /health` - Health check
+
+ ## Model
+
+ Model: `daoqm123/llm-error-classifier`
+
+ ## Usage
+
+ The API automatically loads the model from the Hugging Face Hub on startup.
+
+ ## Deploying to Hugging Face Spaces
+
+ 1. **Create a Space**
+    - Go to https://huggingface.co/spaces/new and choose `Docker` as the SDK (this repo already contains a Dockerfile).
+    - Give the Space a name such as `llm-error-classifier-api` and select the desired hardware (CPU is fine unless you need GPU acceleration).
+    - After the Space is created, copy the Git commands shown in the "Files" tab; you will push the contents of this `api/` folder there.
+
+ 2. **Authenticate locally**
+    ```bash
+    pip install -U "huggingface_hub[cli]"
+    huggingface-cli login
+    ```
+    Use a write token from https://huggingface.co/settings/tokens.
+
+ 3. **Push the backend code**
+    ```bash
+    cd /work/cssema416/202610/12/llm-frontend-for-quang\ \(1\)/api
+    rm -rf .git
+    git init
+    git remote add origin https://huggingface.co/spaces/<username>/<space-name>
+    git add .
+    git commit -m "Deploy FastAPI backend"
+    git push origin main
+    ```
+    Replace `<username>` and `<space-name>` with your actual values. Hugging Face builds the Docker image automatically; the server becomes available at `https://<username>-<space-name>.hf.space`.
+
+ 4. **Configure runtime behavior (optional)**
+    - Set a custom `MODEL_PATH` or other environment variables from the "Settings → Repository secrets" tab inside the Space.
+    - If you need GPU, request the appropriate hardware tier in the hardware selector.
+
+ 5. **Wire up the Vercel frontend**
+    - In `frontend/lib/api.ts` the app reads `process.env.NEXT_PUBLIC_API_URL`.
+    - On Vercel, set `NEXT_PUBLIC_API_URL=https://<username>-<space-name>.hf.space` (no trailing slash) and redeploy the frontend so calls go directly to the Space backend.
+
+ 6. **Verify**
+    - Open the Space URL to confirm the FastAPI app is live (the root path returns FastAPI's default 404 JSON; append `/health` for the health check).
+    - Visit your Vercel deployment and ensure inference requests succeed using the new backend endpoint.
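The request body for `POST /api/classify` follows the `ClassificationRequest` model in `main.py` (`query`, `enabled_tools`, `tool_calling`). A minimal client sketch using only the standard library; the `classify` helper and the placeholder Space URL are illustrative:

```python
import json
from urllib import request as urlrequest

# Build a body matching main.py's ClassificationRequest schema
payload = {
    "query": "What's the weather in New York?",
    "enabled_tools": [{
        "name": "get_weather",
        "description": "Get current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    }],
    "tool_calling": {"name": "get_weather", "arguments": {"location": "New York"}},
}
body = json.dumps(payload).encode("utf-8")

def classify(base_url: str) -> dict:
    # POST the payload to the deployed Space and decode the JSON response
    req = urlrequest.Request(
        f"{base_url}/api/classify",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urlrequest.urlopen(req) as resp:
        return json.loads(resp.read())

# classify("https://<username>-<space-name>.hf.space")  # call against your own Space
```

The response includes `label`, `confidence`, `all_probabilities`, `processing_time_ms`, and `category_color`, per the `ClassificationResponse` model.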
main.py ADDED
@@ -0,0 +1,285 @@
+ """
+ FastAPI Backend for LLM Tool-Use Error Classifier
+ Serves predictions from the fine-tuned Llama-3.2-3B model
+ """
+
+ from contextlib import asynccontextmanager
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import Dict, Any, List
+ import json
+ import os
+ import time
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+ # Global model and tokenizer
+ model = None
+ tokenizer = None
+ device = None
+
+ # Pin to a specific GPU on multi-GPU hosts; ignored on CPU-only machines
+ os.environ["CUDA_VISIBLE_DEVICES"] = "7"
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     """Lifespan context manager for startup and shutdown events"""
+     global model, tokenizer, device
+
+     # Startup
+     print("Loading model...")
+     # Get model path from environment variable; fall back to the HuggingFace Hub repo
+     model_path = os.getenv("MODEL_PATH", "daoqm123/llm-error-classifier")
+     print(f"Model path: {model_path}")
+
+     # Determine device and dtype
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+         dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+         print(f"Using GPU with dtype: {dtype}")
+     else:
+         device = torch.device("cpu")
+         dtype = torch.float32
+         print("Using CPU")
+
+     # Load tokenizer and model
+     # Supports both local paths and HuggingFace Hub paths (e.g., "daoqm123/llm-error-classifier")
+     print(f"Loading tokenizer from: {model_path}")
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+     print(f"Loading model from: {model_path}")
+     model = AutoModelForSequenceClassification.from_pretrained(
+         model_path,
+         torch_dtype=dtype,
+         device_map="auto" if torch.cuda.is_available() else None
+     )
+
+     if not torch.cuda.is_available():
+         model = model.to(device)
+
+     model.eval()
+     print("Model loaded successfully!")
+
+     yield  # Application runs here
+
+     # Shutdown: cleanup code can go here if needed
+
+
+ app = FastAPI(title="LLM Error Classifier API", version="1.0.0", lifespan=lifespan)
+
+ # Enable CORS for frontend
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # In production, specify exact origins
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Label mapping
+ LABEL_MAP = {
+     0: "Correct",
+     1: "No_Tool_Available",
+     2: "Incorrect_Function_Name",
+     3: "Incorrect_Argument_Type",
+     4: "Wrong_Syntax",
+     5: "Wrong_Tool",
+     6: "Incorrect_Argument_Value",
+     7: "Incorrect_Argument_Name"
+ }
+
+ # Color mapping for frontend
+ LABEL_COLORS = {
+     "Correct": "#10B981",
+     "No_Tool_Available": "#F59E0B",
+     "Incorrect_Function_Name": "#EF4444",
+     "Incorrect_Argument_Name": "#EC4899",
+     "Incorrect_Argument_Value": "#8B5CF6",
+     "Incorrect_Argument_Type": "#3B82F6",
+     "Wrong_Tool": "#F97316",
+     "Wrong_Syntax": "#DC2626"
+ }
+
+
+ class ClassificationRequest(BaseModel):
+     """Request body for classification endpoint"""
+     query: str
+     enabled_tools: List[Dict[str, Any]]
+     tool_calling: Dict[str, Any]
+
+
+ class ClassificationResponse(BaseModel):
+     """Response from classification endpoint"""
+     label: str
+     confidence: float
+     all_probabilities: Dict[str, float]
+     processing_time_ms: int
+     category_color: str
+
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     return {
+         "status": "ok",
+         "model_loaded": model is not None,
+         "device": str(device) if device else "not initialized"
+     }
+
+
+ @app.post("/api/classify", response_model=ClassificationResponse)
+ async def classify(request: ClassificationRequest):
+     """
+     Classify a tool call as correct or identify the error type
+     """
+     if model is None or tokenizer is None:
+         raise HTTPException(status_code=503, detail="Model not loaded")
+
+     start_time = time.time()
+
+     try:
+         # Format input as a JSON string (same format as training)
+         input_data = {
+             "query": request.query,
+             "enabled_tools": request.enabled_tools,
+             "tool_calling": request.tool_calling
+         }
+         input_text = json.dumps(input_data)
+
+         # Tokenize
+         inputs = tokenizer(
+             input_text,
+             return_tensors="pt",
+             truncation=True,
+             max_length=512,
+             padding=True
+         )
+
+         # Move to device
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Get prediction
+         with torch.no_grad():
+             outputs = model(**inputs)
+             logits = outputs.logits
+             probs = torch.softmax(logits, dim=-1)[0]
+             pred_idx = torch.argmax(probs).item()
+             confidence = probs[pred_idx].item()
+
+         # Get all probabilities
+         all_probs = {LABEL_MAP[i]: float(probs[i]) for i in range(len(probs))}
+
+         # Get predicted label
+         predicted_label = LABEL_MAP[pred_idx]
+
+         # Calculate processing time
+         processing_time_ms = int((time.time() - start_time) * 1000)
+
+         return ClassificationResponse(
+             label=predicted_label,
+             confidence=confidence,
+             all_probabilities=all_probs,
+             processing_time_ms=processing_time_ms,
+             category_color=LABEL_COLORS.get(predicted_label, "#6B7280")
+         )
+
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Classification error: {str(e)}")
+
+
+ @app.get("/api/examples")
+ async def get_examples():
+     """Return example inputs for testing"""
+     examples = [
+         {
+             "name": "Correct Example",
+             "description": "A properly formed tool call",
+             "data": {
+                 "query": "What's the weather in New York?",
+                 "enabled_tools": [
+                     {
+                         "name": "get_weather",
+                         "description": "Get current weather for a location",
+                         "parameters": {
+                             "type": "object",
+                             "properties": {
+                                 "location": {"type": "string"},
+                                 "units": {"type": "string", "enum": ["celsius", "fahrenheit"]}
+                             },
+                             "required": ["location"]
+                         }
+                     }
+                 ],
+                 "tool_calling": {
+                     "name": "get_weather",
+                     "arguments": {
+                         "location": "New York",
+                         "units": "fahrenheit"
+                     }
+                 }
+             }
+         },
+         {
+             "name": "Wrong Function Name",
+             "description": "Tool call uses incorrect function name",
+             "data": {
+                 "query": "Calculate 25 * 4",
+                 "enabled_tools": [
+                     {
+                         "name": "calculator",
+                         "description": "Perform calculations",
+                         "parameters": {
+                             "type": "object",
+                             "properties": {
+                                 "expression": {"type": "string"}
+                             }
+                         }
+                     }
+                 ],
+                 "tool_calling": {
+                     "name": "calculate",  # Wrong name!
+                     "arguments": {
+                         "expression": "25 * 4"
+                     }
+                 }
+             }
+         },
+         {
+             "name": "Incorrect Argument Type",
+             "description": "Argument has wrong data type",
+             "data": {
+                 "query": "Set a reminder for 3pm",
+                 "enabled_tools": [
+                     {
+                         "name": "set_reminder",
+                         "description": "Create a reminder",
+                         "parameters": {
+                             "type": "object",
+                             "properties": {
+                                 "time": {"type": "string"},
+                                 "message": {"type": "string"}
+                             }
+                         }
+                     }
+                 ],
+                 "tool_calling": {
+                     "name": "set_reminder",
+                     "arguments": {
+                         "time": 1500,  # Should be string!
+                         "message": "Meeting"
+                     }
+                 }
+             }
+         }
+     ]
+
+     return {"examples": examples}
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     # HuggingFace Spaces uses port 7860, but allow override via environment variable
+     port = int(os.getenv("PORT", 7860))
+     # Use 0.0.0.0 to allow external connections
+     uvicorn.run(app, host="0.0.0.0", port=port)
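The endpoint's post-processing (softmax over logits, argmax, `LABEL_MAP` lookup) can be sketched without torch using `math.exp`; the logits below are made up for illustration:

```python
import math

# Same index-to-label mapping as main.py
LABEL_MAP = {0: "Correct", 1: "No_Tool_Available", 2: "Incorrect_Function_Name",
             3: "Incorrect_Argument_Type", 4: "Wrong_Syntax", 5: "Wrong_Tool",
             6: "Incorrect_Argument_Value", 7: "Incorrect_Argument_Name"}

def postprocess(logits):
    # Numerically stable softmax, mirroring torch.softmax(logits, dim=-1)
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Argmax + label lookup, mirroring the classify endpoint
    pred_idx = probs.index(max(probs))
    return LABEL_MAP[pred_idx], probs[pred_idx]

# Made-up logits where index 2 dominates
label, confidence = postprocess([0.1, -1.2, 4.5, 0.3, -0.7, 1.1, 0.0, -2.0])
print(label)  # Incorrect_Function_Name
```

The real endpoint does the same computation on the model's logits tensor and additionally returns the full probability map for the frontend chart.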
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ pydantic==2.5.0
+ torch>=2.0.0
+ transformers>=4.35.0