| """ | |
| Enhanced FastAPI Backend with Feedback Management | |
| """ | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Dict, Union | |
| import json | |
| import time | |
| from pathlib import Path | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| import os | |

app = FastAPI(
    title="Personalized Chatbot API",
    description="FastAPI backend for chatbot with HITL feedback and continuous learning",
    version="2.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
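
# NOTE: allow_origins=["*"] accepts requests from any origin, which is
# convenient for a public demo but too permissive for a real deployment;
# listing the specific front-end origins would be safer.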

# --- DATA MODELS (Fixed for Syntax Stability) ---
class ChatRequest(BaseModel):
    message: str
    # Simplified type hint to prevent SyntaxError: unmatched ']'
    history: List[Dict[str, str]] = []
    max_length: int = 200
    temperature: float = 0.7


class FeedbackRequest(BaseModel):
    user_input: str
    model_reply: str
    user_correction: str
    reason: str = "user_correction"


class ReloadAdapterRequest(BaseModel):
    adapter_path: str


class ChatResponse(BaseModel):
    reply: str
    timestamp: float


class FeedbackResponse(BaseModel):
    status: str
    message: str


class StatsResponse(BaseModel):
    total_interactions: int
    corrections: int
    accepted: int
    correction_rate: float


class CorrectionCountResponse(BaseModel):
    corrections: int
    total: int
    ready_to_train: bool


class DownloadFeedbackResponse(BaseModel):
    content: str
    count: int
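
# For illustration, a minimal request body for POST /chat could look like
# this (the history entries are hypothetical, not from a real session):
#
#   {
#     "message": "What did we just talk about?",
#     "history": [
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello! How can I help?"}
#     ],
#     "max_length": 200,
#     "temperature": 0.7
#   }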

# --- MODEL MANAGER ---
class ModelManager:
    """Singleton model manager: load the model once and reuse it."""
    _instance = None
    _model = None
    _tokenizer = None
    _device = None
    _current_adapter = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def initialize(
        self,
        model_name: str = "meta-llama/Llama-3.2-1B-Instruct",
        adapter_path: Optional[str] = None,
        use_4bit: bool = True
    ):
        """Initialize the model, or reload it with a new adapter."""
        if adapter_path == self._current_adapter and self._model is not None:
            print(f"Model already loaded with adapter: {adapter_path}")
            return

        print(f"Loading model: {model_name}")
        if adapter_path:
            print(f"With adapter: {adapter_path}")

        # Check for GPU
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self._device}")

        try:
            self._tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            print("Did you set HF_TOKEN in Settings > Secrets?")
            raise

        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        # GPU check for 4-bit loading
        if use_4bit and self._device == "cuda":
            print("GPU detected: loading in 4-bit mode")
            try:
                from transformers import BitsAndBytesConfig
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                )
                base_model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=bnb_config,
                    device_map="auto",
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                )
            except ImportError:
                print("bitsandbytes not installed; falling back to standard loading.")
                base_model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    device_map="auto",
                    trust_remote_code=True,
                )
        else:
            print(f"Using {self._device} (no GPU or use_4bit=False); loading standard model.")
            base_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=self._device,
                trust_remote_code=True,
                torch_dtype=torch.float32 if self._device == "cpu" else torch.float16
            )

        if adapter_path and isinstance(adapter_path, str) and adapter_path.strip():
            print(f"Loading LoRA adapter: {adapter_path}")
            try:
                self._model = PeftModel.from_pretrained(
                    base_model,
                    adapter_path,
                    torch_dtype=torch.float16 if self._device == "cuda" else torch.float32
                )
                self._current_adapter = adapter_path
                print("Adapter loaded successfully")
            except Exception as e:
                print(f"Could not load adapter: {e}")
                print("Using base model without adapter")
                self._model = base_model
                self._current_adapter = None
        else:
            self._model = base_model
            self._current_adapter = None

        self._model.eval()
        print("Model ready")

    def generate_reply(
        self,
        user_input: str,
        history: Optional[List[Dict[str, str]]] = None,
        max_length: int = 200,
        temperature: float = 0.7
    ) -> str:
        """Generate a chatbot response."""
        if self._model is None:
            raise RuntimeError("Model not initialized")
        if history is None:
            history = []

        system_prompt = {
            "role": "system",
            "content": (
                "You are Ouro, a continuous-learning AI. You are helpful, "
                "concise, and smart. Your developer is Pierre Ramez, a computer "
                "engineer from Egypt (only mention the developer if the user "
                "asks something like: who is your developer?)."
            )
        }
        messages = [system_prompt] + history + [{"role": "user", "content": user_input}]

        try:
            text = self._tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except Exception:
            # Tokenizer has no chat template; fall back to the raw input.
            text = user_input

        inputs = self._tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512
        ).to(self._device)

        with torch.no_grad():
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=self._tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens (skip the prompt).
        reply = self._tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True
        ).strip()

        # Strip any leaked role marker from the decoded text.
        if "assistant" in reply.lower() and len(reply.split("assistant")) > 1:
            reply = reply.split("assistant")[-1].strip()
        return reply

# --- FEEDBACK MANAGER ---
class FeedbackManager:
    """Manages feedback storage and statistics."""

    def __init__(self, feedback_file: str = "data/feedback.jsonl"):
        self.feedback_file = Path(feedback_file)
        self.feedback_file.parent.mkdir(parents=True, exist_ok=True)

    def save_interaction(
        self,
        user_input: str,
        model_reply: str,
        user_correction: Optional[str] = None,
        reason: Optional[str] = None
    ):
        """Append one interaction to the feedback file."""
        record = {
            "time": time.time(),
            "user_input": user_input,
            "model_reply": model_reply,
            "user_correction": user_correction,
            "accepted": user_correction is None,
            "reason": reason,
        }
        with open(self.feedback_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        return record
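
    # Each line of the JSONL file is a record like the following (values are
    # illustrative); an uncorrected reply has user_correction == null and
    # accepted == true:
    #
    #   {"time": 1700000000.0, "user_input": "Hi", "model_reply": "Hello!",
    #    "user_correction": null, "accepted": true, "reason": null}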

    def get_stats(self) -> Dict:
        """Get feedback statistics."""
        if not self.feedback_file.exists():
            return {
                "total_interactions": 0,
                "corrections": 0,
                "accepted": 0,
                "correction_rate": 0.0
            }

        total = 0
        corrections = 0
        accepted = 0
        with open(self.feedback_file, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    record = json.loads(line)
                    total += 1
                    if record.get("accepted") is False:
                        corrections += 1
                    else:
                        accepted += 1
                except json.JSONDecodeError:
                    # Skip malformed lines rather than failing the whole scan.
                    continue

        correction_rate = corrections / total if total > 0 else 0.0
        return {
            "total_interactions": total,
            "corrections": corrections,
            "accepted": accepted,
            "correction_rate": correction_rate
        }
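
# Sketch of the dict get_stats() returns; the numbers are made up
# (a log of 10 interactions of which 3 were corrected):
#
#   {"total_interactions": 10, "corrections": 3, "accepted": 7,
#    "correction_rate": 0.3}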

model_manager = ModelManager()
feedback_manager = FeedbackManager(feedback_file="data/feedback.jsonl")

# --- APP EVENTS AND ENDPOINTS ---
@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup."""
    print("Starting up...")
    model_manager.initialize(
        # 1. The base model
        model_name="meta-llama/Llama-3.2-3B-Instruct",
        # 2. The LoRA adapter
        adapter_path="pierreramez/Llama-3.2-3B-Instruct-bnb-4bit_finetuned",
        # 3. CPU optimization: skip 4-bit quantization
        use_4bit=False
    )
    print("Ready to serve!")

@app.get("/")
async def root():
    """Root endpoint listing the available routes."""
    return {
        "message": "Personalized Chatbot API v2.0",
        "version": "2.0.0",
        "current_adapter": model_manager._current_adapter,
        "device": model_manager._device,
        "endpoints": {
            "chat": "POST /chat",
            "feedback": "POST /feedback",
            "stats": "GET /stats",
            "download-feedback": "GET /download-feedback",
            "correction-count": "GET /correction-count",
            "clear-feedback": "POST /clear-feedback",
            "reload-adapter": "POST /reload-adapter",
            "health": "GET /health"
        }
    }

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": model_manager._model is not None,
        "current_adapter": model_manager._current_adapter,
        "device": str(model_manager._device)
    }

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Generate a chatbot response and log the interaction."""
    try:
        reply = model_manager.generate_reply(
            user_input=request.message,
            history=request.history,
            max_length=request.max_length,
            temperature=request.temperature
        )
        # Log the exchange; it stays "accepted" until a correction arrives.
        feedback_manager.save_interaction(
            user_input=request.message,
            model_reply=reply,
            user_correction=None,
            reason=None
        )
        return ChatResponse(
            reply=reply,
            timestamp=time.time()
        )
    except Exception as e:
        print(f"Error during chat: {e}")
        raise HTTPException(status_code=500, detail=str(e))
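
# A hypothetical call against a local instance (the port matches the
# uvicorn.run() call at the bottom of this file):
#
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello!"}'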

@app.post("/feedback", response_model=FeedbackResponse)
async def submit_feedback(request: FeedbackRequest):
    """Submit a correction for a model response."""
    try:
        if feedback_manager.feedback_file.exists():
            with open(feedback_manager.feedback_file, "r", encoding="utf-8") as f:
                lines = f.readlines()
        else:
            lines = []

        # Search backwards for the matching accepted interaction and
        # convert it in place into a correction record.
        found = False
        for i in range(len(lines) - 1, -1, -1):
            try:
                record = json.loads(lines[i])
                if (record["user_input"] == request.user_input and
                        record["model_reply"] == request.model_reply and
                        record["accepted"] is True):
                    record["user_correction"] = request.user_correction
                    record["accepted"] = False
                    record["reason"] = request.reason
                    lines[i] = json.dumps(record, ensure_ascii=False) + "\n"
                    found = True
                    break
            except (json.JSONDecodeError, KeyError):
                continue

        if found:
            with open(feedback_manager.feedback_file, "w", encoding="utf-8") as f:
                f.writelines(lines)
            return FeedbackResponse(
                status="success",
                message="Feedback recorded successfully"
            )
        else:
            # No matching interaction found; store the correction as a new entry.
            feedback_manager.save_interaction(
                user_input=request.user_input,
                model_reply=request.model_reply,
                user_correction=request.user_correction,
                reason=request.reason
            )
            return FeedbackResponse(
                status="success",
                message="Feedback recorded as new entry"
            )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
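
# A hypothetical correction for a prior reply (field values are placeholders
# that must match the logged interaction exactly):
#
#   curl -X POST http://localhost:7860/feedback \
#     -H "Content-Type: application/json" \
#     -d '{"user_input": "Hello!", "model_reply": "Hi.",
#          "user_correction": "A friendlier greeting", "reason": "user_correction"}'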

@app.get("/stats", response_model=StatsResponse)
async def get_stats():
    """Get feedback statistics."""
    stats = feedback_manager.get_stats()
    return StatsResponse(**stats)

@app.get("/correction-count", response_model=CorrectionCountResponse)
async def get_correction_count():
    """Get the number of corrections collected so far."""
    if not feedback_manager.feedback_file.exists():
        return CorrectionCountResponse(corrections=0, total=0, ready_to_train=False)

    total = 0
    corrections = 0
    with open(feedback_manager.feedback_file, "r", encoding="utf-8") as f:
        for line in f:
            try:
                record = json.loads(line)
                total += 1
                if record.get("accepted") is False:
                    corrections += 1
            except json.JSONDecodeError:
                continue

    return CorrectionCountResponse(
        corrections=corrections,
        total=total,
        # Flag retraining as worthwhile once 20 corrections have accumulated.
        ready_to_train=corrections >= 20
    )

@app.get("/download-feedback", response_model=DownloadFeedbackResponse)
async def download_feedback():
    """Return the raw feedback file contents."""
    if not feedback_manager.feedback_file.exists():
        return DownloadFeedbackResponse(content="", count=0)

    with open(feedback_manager.feedback_file, "r", encoding="utf-8") as f:
        content = f.read()
    count = len(content.strip().split("\n")) if content.strip() else 0
    return DownloadFeedbackResponse(content=content, count=count)

@app.post("/clear-feedback")
async def clear_feedback():
    """Delete the feedback file."""
    try:
        if feedback_manager.feedback_file.exists():
            feedback_manager.feedback_file.unlink()
            return {"status": "success", "message": "Feedback file cleared"}
        else:
            return {"status": "success", "message": "Feedback file already empty"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/reload-adapter")
async def reload_adapter(request: ReloadAdapterRequest):
    """Hot-reload the model with a new adapter."""
    try:
        model_manager.initialize(
            model_name="meta-llama/Llama-3.2-3B-Instruct",
            adapter_path=request.adapter_path,
            use_4bit=False
        )
        return {"status": "success", "adapter": request.adapter_path, "message": "Adapter reloaded successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reload adapter: {str(e)}")
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) | |