pierreramez committed on
Commit
65db64e
·
verified ·
1 Parent(s): 5ce7ec1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +559 -0
app.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced FastAPI Backend with Feedback Management
3
+ --------------------------------------------------
4
+ New endpoints for production continuous learning workflow:
5
+ - GET /download-feedback: Download feedback for training
6
+ - POST /clear-feedback: Clear feedback after training
7
+ - GET /correction-count: Monitor training readiness
8
+ - POST /reload-adapter: Hot reload new model without restart
9
+
10
+ Deploy to HuggingFace Spaces (FREE):
11
+ 1. Create new Space: "YourUsername/chatbot-api"
12
+ 2. Select: SDK = "Docker"
13
+ 3. Upload: app.py, requirements.txt, Dockerfile, README.md
14
+ """
15
+
16
+ from fastapi import FastAPI, HTTPException
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from pydantic import BaseModel
19
+ from typing import List, Optional, Dict
20
+ import json
21
+ import time
22
+ from pathlib import Path
23
+ import torch
24
+ from transformers import AutoTokenizer, AutoModelForCausalLM
25
+ from peft import PeftModel
26
+
27
# Metadata handed to FastAPI when the application object is created.
_API_METADATA = {
    "title": "Personalized Chatbot API",
    "description": "FastAPI backend for chatbot with HITL feedback and continuous learning",
    "version": "2.0.0",
}

app = FastAPI(**_API_METADATA)

# CORS is left fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended before exposing the API publicly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
40
+
41
+
42
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    # The new user message to answer.
    message: str
    # Earlier conversation turns; the chat handler places them before the new
    # user message. Presumably {"role": ..., "content": ...} dicts, matching
    # the entry the handler appends — verify against the client.
    history: Optional[List[Dict[str, str]]] = []
    # Forwarded to generation as the number of new tokens to produce.
    max_length: Optional[int] = 200
    # Sampling temperature forwarded to generation.
    temperature: Optional[float] = 0.7
47
+
48
+
49
class FeedbackRequest(BaseModel):
    """Request body for ``POST /feedback``: a user's correction of one reply."""

    # The prompt the user originally sent.
    user_input: str
    # The reply the model produced for that prompt.
    model_reply: str
    # The text the user says the reply should have been.
    user_correction: str
    # Free-form label stored alongside the correction.
    reason: Optional[str] = "user_correction"
54
+
55
+
56
class ReloadAdapterRequest(BaseModel):
    """Request body for ``POST /reload-adapter``."""

    # Adapter identifier handed to the model manager, e.g. "username/adapter-v2".
    adapter_path: str
58
+
59
+
60
class ChatResponse(BaseModel):
    """Response body of ``POST /chat``."""

    # Generated reply text.
    reply: str
    # Time the response was built (the chat handler fills it with time.time()).
    timestamp: float
63
+
64
+
65
class FeedbackResponse(BaseModel):
    """Response body of ``POST /feedback``."""

    # Outcome marker; the feedback handler returns "success".
    status: str
    # Human-readable description of how the feedback was stored.
    message: str
68
+
69
+
70
class StatsResponse(BaseModel):
    """Response body of ``GET /stats``; mirrors the feedback statistics dict."""

    # Number of stored interaction records.
    total_interactions: int
    # Records whose "accepted" flag is False.
    corrections: int
    # Records counted as accepted.
    accepted: int
    # corrections divided by total_interactions (0.0 when there are none).
    correction_rate: float
75
+
76
+
77
class CorrectionCountResponse(BaseModel):
    """Response body of ``GET /correction-count``."""

    # Number of stored records marked as corrected.
    corrections: int
    # Number of stored records overall.
    total: int
    # True once enough corrections have accumulated to start a training run.
    ready_to_train: bool
81
+
82
+
83
class DownloadFeedbackResponse(BaseModel):
    """Response body of ``GET /download-feedback``."""

    # Raw text of the feedback file (JSON Lines), empty when no file exists.
    content: str
    # Number of lines in ``content``.
    count: int
86
+
87
+
88
class ModelManager:
    """Singleton that loads the tokenizer and model once and serves replies.

    ``__new__`` returns the same object on every call, so the state set by
    :meth:`initialize` is shared by every ``ModelManager()`` reference.
    """

    _instance = None
    _model = None
    _tokenizer = None
    _device = None
    _current_adapter = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def initialize(
        self,
        model_name: str = "meta-llama/Llama-3.2-1B-Instruct",
        adapter_path: Optional[str] = None,
        use_4bit: bool = True
    ):
        """Initialize or reload model with new adapter.

        Args:
            model_name: Base model identifier passed to ``from_pretrained``.
            adapter_path: LoRA adapter to stack on the base model; ``None``
                or a blank string serves the base model alone.
            use_4bit: Load the base model 4-bit quantized. Only honoured
                when CUDA is available.

        If the adapter cannot be loaded, the error is printed and the base
        model is served instead; ``_current_adapter`` is then ``None``.
        """
        # Only the adapter is compared here: a call with a different
        # model_name but the same adapter is treated as already loaded.
        if adapter_path == self._current_adapter and self._model is not None:
            print(f"Model already loaded with adapter: {adapter_path}")
            return

        print(f"Loading model: {model_name}")
        if adapter_path:
            print(f"With adapter: {adapter_path}")

        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self._device}")

        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )

        # Fall back to the EOS token when the tokenizer defines no pad token.
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        if use_4bit and torch.cuda.is_available():
            from transformers import BitsAndBytesConfig

            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )

            base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True,
                torch_dtype=torch.float16,
            )
        else:
            base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                trust_remote_code=True,
            )

        if adapter_path and (isinstance(adapter_path, str) and adapter_path.strip()):
            print(f"Loading LoRA adapter: {adapter_path}")
            try:
                self._model = PeftModel.from_pretrained(
                    base_model,
                    adapter_path,
                    torch_dtype=torch.float16
                )
                self._current_adapter = adapter_path
                print("✅ Adapter loaded successfully")
            except Exception as e:
                # Deliberate best-effort: keep serving with the base model.
                print(f"⚠️ Could not load adapter: {e}")
                print(" Using base model without adapter")
                self._model = base_model
                self._current_adapter = None
        else:
            self._model = base_model
            self._current_adapter = None

        self._model.eval()
        print("Model ready")

    @staticmethod
    def _sampling_kwargs(temperature: Optional[float]) -> Dict[str, object]:
        """Map a temperature onto decoding arguments for ``generate``.

        A temperature of zero or below cannot be used for sampling, so it
        selects greedy decoding instead of raising inside ``generate``.
        """
        if temperature is not None and temperature <= 0:
            return {"do_sample": False}
        return {"do_sample": True, "temperature": temperature, "top_p": 0.9}

    def generate_reply(
        self,
        user_input: str,
        history: List[Dict[str, str]] = None,
        max_length: int = 200,
        temperature: float = 0.7
    ) -> str:
        """Generate chatbot response.

        Args:
            user_input: The new user message.
            history: Earlier chat messages placed before ``user_input``.
            max_length: Passed to ``generate`` as ``max_new_tokens``.
            temperature: Sampling temperature; values <= 0 decode greedily.

        Returns:
            The decoded continuation only (prompt tokens removed), stripped.

        Raises:
            RuntimeError: If :meth:`initialize` has not loaded a model yet.
        """
        if self._model is None:
            raise RuntimeError("Model not initialized")

        if history is None:
            history = []

        messages = history + [{"role": "user", "content": user_input}]

        try:
            text = self._tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except Exception:
            # Templating failed: fall back to the raw message, which drops
            # the history for this turn.
            text = user_input

        inputs = self._tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self._device)

        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=max_length,
                pad_token_id=self._tokenizer.eos_token_id,
                **self._sampling_kwargs(temperature)
            )

        # Slice off the prompt so only newly generated tokens are decoded.
        prompt_length = inputs["input_ids"].shape[1]
        reply = self._tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        ).strip()

        return reply
222
+
223
+
224
class FeedbackManager:
    """Manages feedback storage and statistics.

    Interactions are appended to a JSON Lines file, one record per line.
    """

    def __init__(self, feedback_file: str = "data/feedback.jsonl"):
        """Remember the feedback path and create its parent directory."""
        self.feedback_file = Path(feedback_file)
        self.feedback_file.parent.mkdir(parents=True, exist_ok=True)

    def save_interaction(
        self,
        user_input: str,
        model_reply: str,
        user_correction: Optional[str] = None,
        reason: Optional[str] = None
    ):
        """Save interaction to feedback file.

        Args:
            user_input: The prompt sent by the user.
            model_reply: The reply the model produced.
            user_correction: Corrected reply, or ``None`` if none was given.
            reason: Optional label describing the correction.

        Returns:
            The record dict that was appended to the file.
        """
        record = {
            "time": time.time(),
            "user_input": user_input,
            "model_reply": model_reply,
            "user_correction": user_correction,
            # A record without a correction counts as accepted.
            "accepted": user_correction is None,
            "reason": reason,
        }

        with open(self.feedback_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

        return record

    def get_stats(self) -> Dict:
        """Get feedback statistics.

        Returns:
            Dict with ``total_interactions``, ``corrections``, ``accepted``
            and ``correction_rate``. Blank lines, lines that are not valid
            JSON and lines that are not JSON objects are skipped, so
            ``total_interactions == corrections + accepted`` always holds.
        """
        total = 0
        corrections = 0

        if self.feedback_file.exists():
            with open(self.feedback_file, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        # Tolerate a corrupt line instead of failing the scan.
                        continue
                    if not isinstance(record, dict):
                        continue
                    total += 1
                    if record.get("accepted") is False:
                        corrections += 1

        correction_rate = corrections / total if total > 0 else 0.0

        return {
            "total_interactions": total,
            "corrections": corrections,
            "accepted": total - corrections,
            "correction_rate": correction_rate
        }
286
+
287
+
288
# Module-level singletons shared by every endpoint below.
model_manager = ModelManager()
# Feedback is persisted as JSON Lines under data/.
feedback_manager = FeedbackManager("data/feedback.jsonl")
290
+
291
+
292
@app.on_event("startup")
async def startup_event():
    """Load the model once when the application starts."""
    print("Starting up...")

    startup_config = {
        "model_name": "meta-llama/Llama-3.2-1B-Instruct",
        # Update this after training: "username/adapter-v1"
        "adapter_path": None,
        "use_4bit": True,
    }
    model_manager.initialize(**startup_config)

    print("Ready to serve!")
304
+
305
+
306
@app.get("/")
async def root():
    """Describe the API: version, active adapter and the available routes."""
    endpoint_index = {
        "chat": "POST /chat",
        "feedback": "POST /feedback",
        "stats": "GET /stats",
        "download-feedback": "GET /download-feedback",
        "correction-count": "GET /correction-count",
        "clear-feedback": "POST /clear-feedback",
        "reload-adapter": "POST /reload-adapter",
        "health": "GET /health",
    }
    return {
        "message": "Personalized Chatbot API v2.0",
        "version": "2.0.0",
        "current_adapter": model_manager._current_adapter,
        "endpoints": endpoint_index,
    }
324
+
325
+
326
@app.get("/health")
async def health_check():
    """Report liveness plus whether a model is loaded and on which device."""
    is_loaded = model_manager._model is not None
    device_name = str(model_manager._device)
    return {
        "status": "healthy",
        "model_loaded": is_loaded,
        "current_adapter": model_manager._current_adapter,
        "device": device_name,
    }
335
+
336
+
337
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Generate a reply, log the interaction as accepted, and return it.

    Any failure while generating or logging is reported as HTTP 500 with the
    exception text as the detail.
    """
    try:
        generated = model_manager.generate_reply(
            user_input=request.message,
            history=request.history,
            max_length=request.max_length,
            temperature=request.temperature,
        )

        # Every reply is stored without a correction; POST /feedback may
        # later flip the record to "not accepted".
        feedback_manager.save_interaction(
            user_input=request.message,
            model_reply=generated,
            user_correction=None,
            reason=None,
        )

        return ChatResponse(reply=generated, timestamp=time.time())
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
362
+
363
+
364
@app.post("/feedback", response_model=FeedbackResponse)
async def submit_feedback(request: FeedbackRequest):
    """Submit correction for a model response.

    The newest stored record that has the same ``user_input`` and
    ``model_reply`` and is still marked accepted is rewritten in place with
    the correction. If there is no such record, or no feedback file yet, the
    correction is appended as a new entry.

    Raises:
        HTTPException: 500 with the exception text if storage fails.
    """
    try:
        feedback_path = feedback_manager.feedback_file

        # The file is absent before the first chat and after /clear-feedback;
        # treat that as "nothing to match" rather than an error.
        if feedback_path.exists():
            with open(feedback_path, "r", encoding="utf-8") as f:
                lines = f.readlines()
        else:
            lines = []

        found = False
        # Walk backwards so the most recent matching interaction wins.
        for i in range(len(lines) - 1, -1, -1):
            try:
                record = json.loads(lines[i])
                is_match = (
                    record["user_input"] == request.user_input
                    and record["model_reply"] == request.model_reply
                    and record["accepted"] is True
                )
            except (json.JSONDecodeError, KeyError, TypeError):
                # Corrupt, incomplete or non-object line: cannot match.
                continue

            if is_match:
                record["user_correction"] = request.user_correction
                record["accepted"] = False
                record["reason"] = request.reason
                lines[i] = json.dumps(record, ensure_ascii=False) + "\n"
                found = True
                break

        if found:
            with open(feedback_path, "w", encoding="utf-8") as f:
                f.writelines(lines)

            return FeedbackResponse(
                status="success",
                message="Feedback recorded successfully"
            )

        feedback_manager.save_interaction(
            user_input=request.user_input,
            model_reply=request.model_reply,
            user_correction=request.user_correction,
            reason=request.reason
        )

        return FeedbackResponse(
            status="success",
            message="Feedback recorded as new entry"
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
412
+
413
+
414
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Return the aggregate feedback statistics."""
    summary = feedback_manager.get_stats()
    return StatsResponse(
        total_interactions=summary["total_interactions"],
        corrections=summary["corrections"],
        accepted=summary["accepted"],
        correction_rate=summary["correction_rate"],
    )
419
+
420
+
421
@app.get("/correction-count", response_model=CorrectionCountResponse)
async def get_correction_count():
    """
    Get count of corrections for training readiness monitoring.

    Use this to check if you have enough corrections to train.
    """
    # Number of corrections at which a training run is considered worthwhile.
    min_corrections_to_train = 20

    # Reuse the feedback manager's scan instead of re-parsing the file here;
    # it already returns zero counts when the file does not exist.
    stats = feedback_manager.get_stats()
    corrections = stats["corrections"]

    return CorrectionCountResponse(
        corrections=corrections,
        total=stats["total_interactions"],
        ready_to_train=corrections >= min_corrections_to_train
    )
453
+
454
+
455
@app.get("/download-feedback", response_model=DownloadFeedbackResponse)
async def download_feedback():
    """
    Return the raw feedback file so it can be used for training.

    Call this from your training notebook to pull the feedback collected
    by the production backend.

    Example:
    ```python
    response = requests.get(f"{API_URL}/download-feedback")
    feedback_data = response.json()

    with open(HITL_FILE, 'w') as f:
        f.write(feedback_data["content"])
    ```
    """
    feedback_path = feedback_manager.feedback_file

    # Nothing collected yet: answer with an empty payload.
    if not feedback_path.exists():
        return DownloadFeedbackResponse(content="", count=0)

    content = feedback_path.read_text(encoding='utf-8')

    # Count lines of the stripped text; an all-whitespace file counts as 0.
    trimmed = content.strip()
    count = len(trimmed.split('\n')) if trimmed else 0

    return DownloadFeedbackResponse(content=content, count=count)
486
+
487
+
488
@app.post("/clear-feedback")
async def clear_feedback():
    """
    Delete the feedback file once a training cycle is finished.

    Call this after downloading the feedback and completing training so the
    next cycle starts from an empty file.

    Example:
    ```python
    requests.post(f"{API_URL}/clear-feedback")
    ```
    """
    try:
        feedback_path = feedback_manager.feedback_file

        if not feedback_path.exists():
            return {
                "status": "success",
                "message": "Feedback file already empty",
            }

        feedback_path.unlink()
        return {
            "status": "success",
            "message": "Feedback file cleared",
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
515
+
516
+
517
@app.post("/reload-adapter")
async def reload_adapter(request: ReloadAdapterRequest):
    """
    Hot reload model with new adapter without restarting the Space.

    This allows you to deploy new models without downtime.

    Example:
    ```python
    # After training and pushing to HF Hub
    requests.post(
        f"{API_URL}/reload-adapter",
        json={"adapter_path": "username/adapter-v2"}
    )
    ```

    Raises:
        HTTPException: 500 if loading raised, or if the requested adapter
            could not be applied and the base model is being served instead.
    """
    try:
        model_manager.initialize(
            model_name="meta-llama/Llama-3.2-1B-Instruct",
            adapter_path=request.adapter_path,
            use_4bit=True
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to reload adapter: {str(e)}"
        ) from e

    # initialize() swallows adapter load errors and falls back to the base
    # model, leaving _current_adapter unset; report that instead of claiming
    # success. A blank path intentionally selects the base model.
    requested = request.adapter_path
    if requested.strip() and model_manager._current_adapter != requested:
        raise HTTPException(
            status_code=500,
            detail=(
                f"Failed to reload adapter: '{requested}' could not be "
                "loaded; serving base model without adapter"
            )
        )

    return {
        "status": "success",
        "adapter": request.adapter_path,
        "message": "Adapter reloaded successfully"
    }
549
+
550
+
551
if __name__ == "__main__":
    import uvicorn

    # Direct-run settings: listen on all interfaces, port 7860, auto-reload on.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "reload": True,
    }
    uvicorn.run("app:app", **server_options)