pierreramez committed on
Commit
33b76bd
·
verified ·
1 Parent(s): 6b940c7

Fixed app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -73
app.py CHANGED
@@ -5,7 +5,7 @@ Enhanced FastAPI Backend with Feedback Management
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel
8
- from typing import List, Optional, Dict
9
  import json
10
  import time
11
  from pathlib import Path
@@ -28,52 +28,48 @@ app.add_middleware(
28
  allow_headers=["*"],
29
  )
30
 
 
31
 
32
  class ChatRequest(BaseModel):
33
  message: str
34
- history: Optional<List[Dict[str, str]]] = []
35
- max_length: Optional[int] = 200
36
- temperature: Optional[float] = 0.7
37
-
38
 
39
  class FeedbackRequest(BaseModel):
40
  user_input: str
41
  model_reply: str
42
  user_correction: str
43
- reason: Optional[str] = "user_correction"
44
-
45
 
46
  class ReloadAdapterRequest(BaseModel):
47
  adapter_path: str
48
 
49
-
50
  class ChatResponse(BaseModel):
51
  reply: str
52
  timestamp: float
53
 
54
-
55
  class FeedbackResponse(BaseModel):
56
  status: str
57
  message: str
58
 
59
-
60
  class StatsResponse(BaseModel):
61
  total_interactions: int
62
  corrections: int
63
  accepted: int
64
  correction_rate: float
65
 
66
-
67
  class CorrectionCountResponse(BaseModel):
68
  corrections: int
69
  total: int
70
  ready_to_train: bool
71
 
72
-
73
  class DownloadFeedbackResponse(BaseModel):
74
  content: str
75
  count: int
76
 
 
77
 
78
  class ModelManager:
79
  """Singleton model manager to load model once and reuse."""
@@ -121,7 +117,7 @@ class ModelManager:
121
  if self._tokenizer.pad_token is None:
122
  self._tokenizer.pad_token = self._tokenizer.eos_token
123
 
124
- # CRITICAL FIX: Only try 4-bit if we actually have a GPU
125
  if use_4bit and self._device == "cuda":
126
  print("🚀 GPU detected: Loading in 4-bit mode")
127
  try:
@@ -154,11 +150,10 @@ class ModelManager:
154
  model_name,
155
  device_map=self._device,
156
  trust_remote_code=True,
157
- # Use float32 for CPU stability
158
  torch_dtype=torch.float32 if self._device == "cpu" else torch.float16
159
  )
160
 
161
- if adapter_path and (isinstance(adapter_path, str) and adapter_path.strip()):
162
  print(f"Loading LoRA adapter: {adapter_path}")
163
  try:
164
  self._model = PeftModel.from_pretrained(
@@ -227,18 +222,17 @@ class ModelManager:
227
  skip_special_tokens=True
228
  ).strip()
229
 
230
- # Remove the system/user prompt if it leaked into response
231
  if "assistant" in reply.lower() and len(reply.split("assistant")) > 1:
232
  reply = reply.split("assistant")[-1].strip()
233
 
234
  return reply
235
 
 
236
 
237
  class FeedbackManager:
238
  """Manages feedback storage and statistics."""
239
  def __init__(self, feedback_file: str = "data/feedback.jsonl"):
240
  self.feedback_file = Path(feedback_file)
241
- # Ensure directory exists (Handled by Dockerfile too, but good safety)
242
  self.feedback_file.parent.mkdir(parents=True, exist_ok=True)
243
 
244
  def save_interaction(
@@ -298,10 +292,10 @@ class FeedbackManager:
298
  "correction_rate": correction_rate
299
  }
300
 
301
-
302
  model_manager = ModelManager()
303
  feedback_manager = FeedbackManager(feedback_file="data/feedback.jsonl")
304
 
 
305
 
306
  @app.on_event("startup")
307
  async def startup_event():
@@ -310,20 +304,17 @@ async def startup_event():
310
 
311
  model_manager.initialize(
312
  # 1. The Base Model (The heavy lifter)
313
- # We use the official Llama 3.2 3B Instruct as the foundation
314
  model_name="meta-llama/Llama-3.2-3B-Instruct",
315
 
316
- # 2. Adapter (The personalization)
317
- adapter_path="pierreramez/Llama-3.2-3B-Instruct-bnb-4bit_finetuned",
318
 
319
- # 3. CPU Optimization
320
- # Must be False for the free CPU tier
321
  use_4bit=False
322
  )
323
 
324
  print("Ready to serve!")
325
 
326
-
327
  @app.get("/")
328
  async def root():
329
  """Root endpoint"""
@@ -344,7 +335,6 @@ async def root():
344
  }
345
  }
346
 
347
-
348
  @app.get("/health")
349
  async def health_check():
350
  """Health check endpoint"""
@@ -355,7 +345,6 @@ async def health_check():
355
  "device": str(model_manager._device)
356
  }
357
 
358
-
359
  @app.post("/chat", response_model=ChatResponse)
360
  async def chat(request: ChatRequest):
361
  """Generate chatbot response."""
@@ -383,12 +372,10 @@ async def chat(request: ChatRequest):
383
  print(f"Error during chat: {e}")
384
  raise HTTPException(status_code=500, detail=str(e))
385
 
386
-
387
  @app.post("/feedback", response_model=FeedbackResponse)
388
  async def submit_feedback(request: FeedbackRequest):
389
  """Submit correction for a model response."""
390
  try:
391
- # Optimistic feedback update: try to find existing entry
392
  if feedback_manager.feedback_file.exists():
393
  with open(feedback_manager.feedback_file, "r", encoding="utf-8") as f:
394
  lines = f.readlines()
@@ -396,7 +383,6 @@ async def submit_feedback(request: FeedbackRequest):
396
  lines = []
397
 
398
  found = False
399
- # Search backwards to find the most recent matching interaction
400
  for i in range(len(lines) - 1, -1, -1):
401
  try:
402
  record = json.loads(lines[i])
@@ -423,7 +409,6 @@ async def submit_feedback(request: FeedbackRequest):
423
  message="Feedback recorded successfully"
424
  )
425
  else:
426
- # If not found (e.g., app restarted), just append new record
427
  feedback_manager.save_interaction(
428
  user_input=request.user_input,
429
  model_reply=request.model_reply,
@@ -439,27 +424,20 @@ async def submit_feedback(request: FeedbackRequest):
439
  except Exception as e:
440
  raise HTTPException(status_code=500, detail=str(e))
441
 
442
-
443
  @app.get("/stats", response_model=StatsResponse)
444
  async def get_stats():
445
  """Get feedback statistics."""
446
  stats = feedback_manager.get_stats()
447
  return StatsResponse(**stats)
448
 
449
-
450
  @app.get("/correction-count", response_model=CorrectionCountResponse)
451
  async def get_correction_count():
452
- """Get count of corrections for training readiness monitoring."""
453
  if not feedback_manager.feedback_file.exists():
454
- return CorrectionCountResponse(
455
- corrections=0,
456
- total=0,
457
- ready_to_train=False
458
- )
459
 
460
  total = 0
461
  corrections = 0
462
-
463
  with open(feedback_manager.feedback_file, "r", encoding="utf-8") as f:
464
  for line in f:
465
  try:
@@ -469,72 +447,48 @@ async def get_correction_count():
469
  corrections += 1
470
  except:
471
  pass
472
-
473
  return CorrectionCountResponse(
474
  corrections=corrections,
475
  total=total,
476
  ready_to_train=corrections >= 20
477
  )
478
 
479
-
480
  @app.get("/download-feedback", response_model=DownloadFeedbackResponse)
481
  async def download_feedback():
482
- """Download feedback file for training."""
483
  if not feedback_manager.feedback_file.exists():
484
- return DownloadFeedbackResponse(
485
- content="",
486
- count=0
487
- )
488
 
489
  with open(feedback_manager.feedback_file, 'r', encoding='utf-8') as f:
490
  content = f.read()
491
  count = len(content.strip().split('\n')) if content.strip() else 0
492
 
493
- return DownloadFeedbackResponse(
494
- content=content,
495
- count=count
496
- )
497
-
498
 
499
  @app.post("/clear-feedback")
500
  async def clear_feedback():
501
- """Clear feedback file after training."""
502
  try:
503
  if feedback_manager.feedback_file.exists():
504
  feedback_manager.feedback_file.unlink()
505
- return {
506
- "status": "success",
507
- "message": "Feedback file cleared"
508
- }
509
  else:
510
- return {
511
- "status": "success",
512
- "message": "Feedback file already empty"
513
- }
514
  except Exception as e:
515
  raise HTTPException(status_code=500, detail=str(e))
516
 
517
-
518
  @app.post("/reload-adapter")
519
  async def reload_adapter(request: ReloadAdapterRequest):
520
- """Hot reload model with new adapter."""
521
  try:
522
  model_manager.initialize(
523
  model_name="meta-llama/Llama-3.2-1B-Instruct",
524
  adapter_path=request.adapter_path,
525
- use_4bit=True
526
  )
527
- return {
528
- "status": "success",
529
- "adapter": request.adapter_path,
530
- "message": "Adapter reloaded successfully"
531
- }
532
  except Exception as e:
533
- raise HTTPException(
534
- status_code=500,
535
- detail=f"Failed to reload adapter: {str(e)}"
536
- )
537
-
538
 
539
  if __name__ == "__main__":
540
  import uvicorn
 
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel
8
+ from typing import Optional, List, Dict, Union
9
  import json
10
  import time
11
  from pathlib import Path
 
28
  allow_headers=["*"],
29
  )
30
 
31
+ # --- DATA MODELS (Fixed for Syntax Stability) ---
32
 
33
class ChatRequest(BaseModel):
    """Request payload for POST /chat."""

    # The user's message for this turn.
    message: str
    # Prior turns, presumably as role/content dicts — confirm against the
    # frontend caller. Pydantic deep-copies field defaults per instance,
    # so the mutable [] default is safe here.
    history: List[Dict[str, str]] = []
    # Generation knobs — presumably forwarded to the model's generate call;
    # the /chat handler body is not visible in this chunk, so confirm.
    max_length: int = 200
    temperature: float = 0.7
39
 
40
class FeedbackRequest(BaseModel):
    """Request payload for POST /feedback."""

    # Original prompt the user sent.
    user_input: str
    # Reply the model gave (the one being corrected).
    model_reply: str
    # The user's corrected version of the reply.
    user_correction: str
    # Tag stored with the feedback record; defaults to a plain correction.
    reason: str = "user_correction"
 
45
 
46
class ReloadAdapterRequest(BaseModel):
    """Request payload for POST /reload-adapter."""

    # LoRA adapter to load — handed to PeftModel.from_pretrained by the
    # model manager, so a Hub repo id or a local path both work.
    adapter_path: str
48
 
 
49
class ChatResponse(BaseModel):
    """Response body for POST /chat."""

    # Generated assistant reply text.
    reply: str
    # When the reply was produced — presumably time.time(); the /chat
    # handler body is not visible in this chunk, so confirm.
    timestamp: float
52
 
 
53
class FeedbackResponse(BaseModel):
    """Response body for POST /feedback."""

    # Outcome flag for the feedback submission.
    status: str
    # Human-readable description, e.g. "Feedback recorded successfully".
    message: str
56
 
 
57
class StatsResponse(BaseModel):
    """Response body for GET /stats, built from FeedbackManager.get_stats()."""

    # Total interactions recorded in the feedback file.
    total_interactions: int
    # Records carrying a user correction.
    corrections: int
    # Records accepted without correction.
    accepted: int
    # Correction rate as computed by get_stats() (its body is not fully
    # visible in this chunk — confirm whether it is a fraction or percent).
    correction_rate: float
62
 
 
63
class CorrectionCountResponse(BaseModel):
    """Response body for GET /correction-count."""

    # Number of correction records found in the feedback file.
    corrections: int
    # Total feedback records scanned.
    total: int
    # True once corrections >= 20 (threshold hard-coded in the endpoint).
    ready_to_train: bool
67
 
 
68
class DownloadFeedbackResponse(BaseModel):
    """Response body for GET /download-feedback."""

    # Raw text of the feedback JSONL file ("" when the file is absent).
    content: str
    # Number of non-empty lines in content (0 when absent or empty).
    count: int
71
 
72
+ # --- MODEL MANAGER ---
73
 
74
  class ModelManager:
75
  """Singleton model manager to load model once and reuse."""
 
117
  if self._tokenizer.pad_token is None:
118
  self._tokenizer.pad_token = self._tokenizer.eos_token
119
 
120
+ # GPU check for 4-bit loading
121
  if use_4bit and self._device == "cuda":
122
  print("🚀 GPU detected: Loading in 4-bit mode")
123
  try:
 
150
  model_name,
151
  device_map=self._device,
152
  trust_remote_code=True,
 
153
  torch_dtype=torch.float32 if self._device == "cpu" else torch.float16
154
  )
155
 
156
+ if adapter_path and isinstance(adapter_path, str) and adapter_path.strip():
157
  print(f"Loading LoRA adapter: {adapter_path}")
158
  try:
159
  self._model = PeftModel.from_pretrained(
 
222
  skip_special_tokens=True
223
  ).strip()
224
 
 
225
  if "assistant" in reply.lower() and len(reply.split("assistant")) > 1:
226
  reply = reply.split("assistant")[-1].strip()
227
 
228
  return reply
229
 
230
+ # --- FEEDBACK MANAGER ---
231
 
232
  class FeedbackManager:
233
  """Manages feedback storage and statistics."""
234
  def __init__(self, feedback_file: str = "data/feedback.jsonl"):
235
  self.feedback_file = Path(feedback_file)
 
236
  self.feedback_file.parent.mkdir(parents=True, exist_ok=True)
237
 
238
  def save_interaction(
 
292
  "correction_rate": correction_rate
293
  }
294
 
 
295
  model_manager = ModelManager()
296
  feedback_manager = FeedbackManager(feedback_file="data/feedback.jsonl")
297
 
298
+ # --- APP EVENTS AND ENDPOINTS ---
299
 
300
  @app.on_event("startup")
301
  async def startup_event():
 
304
 
305
  model_manager.initialize(
306
  # 1. The Base Model (The heavy lifter)
 
307
  model_name="meta-llama/Llama-3.2-3B-Instruct",
308
 
309
+ # 2. Adapter (The personalization) - YOUR SPECIFIC REPO
310
+ adapter_path="pierreramez/Llama-3.2-3B-Instruct-bnb-4bit_finetuned",
311
 
312
+ # 3. CPU Optimization (Must be False for free tier)
 
313
  use_4bit=False
314
  )
315
 
316
  print("Ready to serve!")
317
 
 
318
  @app.get("/")
319
  async def root():
320
  """Root endpoint"""
 
335
  }
336
  }
337
 
 
338
  @app.get("/health")
339
  async def health_check():
340
  """Health check endpoint"""
 
345
  "device": str(model_manager._device)
346
  }
347
 
 
348
  @app.post("/chat", response_model=ChatResponse)
349
  async def chat(request: ChatRequest):
350
  """Generate chatbot response."""
 
372
  print(f"Error during chat: {e}")
373
  raise HTTPException(status_code=500, detail=str(e))
374
 
 
375
  @app.post("/feedback", response_model=FeedbackResponse)
376
  async def submit_feedback(request: FeedbackRequest):
377
  """Submit correction for a model response."""
378
  try:
 
379
  if feedback_manager.feedback_file.exists():
380
  with open(feedback_manager.feedback_file, "r", encoding="utf-8") as f:
381
  lines = f.readlines()
 
383
  lines = []
384
 
385
  found = False
 
386
  for i in range(len(lines) - 1, -1, -1):
387
  try:
388
  record = json.loads(lines[i])
 
409
  message="Feedback recorded successfully"
410
  )
411
  else:
 
412
  feedback_manager.save_interaction(
413
  user_input=request.user_input,
414
  model_reply=request.model_reply,
 
424
  except Exception as e:
425
  raise HTTPException(status_code=500, detail=str(e))
426
 
 
427
@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Return aggregate feedback statistics as a StatsResponse."""
    return StatsResponse(**feedback_manager.get_stats())
432
 
 
433
  @app.get("/correction-count", response_model=CorrectionCountResponse)
434
  async def get_correction_count():
435
+ """Get count of corrections."""
436
  if not feedback_manager.feedback_file.exists():
437
+ return CorrectionCountResponse(corrections=0, total=0, ready_to_train=False)
 
 
 
 
438
 
439
  total = 0
440
  corrections = 0
 
441
  with open(feedback_manager.feedback_file, "r", encoding="utf-8") as f:
442
  for line in f:
443
  try:
 
447
  corrections += 1
448
  except:
449
  pass
 
450
  return CorrectionCountResponse(
451
  corrections=corrections,
452
  total=total,
453
  ready_to_train=corrections >= 20
454
  )
455
 
 
456
@app.get("/download-feedback", response_model=DownloadFeedbackResponse)
async def download_feedback():
    """Download feedback file."""
    path = feedback_manager.feedback_file
    if not path.exists():
        return DownloadFeedbackResponse(content="", count=0)

    # feedback_file is a pathlib.Path, so read it directly.
    content = path.read_text(encoding="utf-8")
    stripped = content.strip()
    # Line count of the trimmed text; 0 for a blank/empty file.
    count = stripped.count("\n") + 1 if stripped else 0
    return DownloadFeedbackResponse(content=content, count=count)
 
 
 
 
467
 
468
  @app.post("/clear-feedback")
469
  async def clear_feedback():
470
+ """Clear feedback file."""
471
  try:
472
  if feedback_manager.feedback_file.exists():
473
  feedback_manager.feedback_file.unlink()
474
+ return {"status": "success", "message": "Feedback file cleared"}
 
 
 
475
  else:
476
+ return {"status": "success", "message": "Feedback file already empty"}
 
 
 
477
  except Exception as e:
478
  raise HTTPException(status_code=500, detail=str(e))
479
 
 
480
@app.post("/reload-adapter")
async def reload_adapter(request: ReloadAdapterRequest):
    """Hot reload model."""
    # NOTE(review): this reloads onto the Llama-3.2-1B base while startup
    # loads the 3B base — confirm the submitted adapter targets 1B.
    try:
        model_manager.initialize(
            model_name="meta-llama/Llama-3.2-1B-Instruct",
            adapter_path=request.adapter_path,
            use_4bit=False,
        )
        return {
            "status": "success",
            "adapter": request.adapter_path,
            "message": "Adapter reloaded successfully",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reload adapter: {str(e)}")
 
 
 
 
492
 
493
  if __name__ == "__main__":
494
  import uvicorn