from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pathlib import Path
from typing import List, Optional

app = FastAPI(title="DNAI Humour Chatbot API", version="1.1")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals populated once at startup
model = None
tokenizer = None
MODEL_NAME = "DarkNeuronAI/dnai-humour-0.5B-instruct"


@app.on_event("startup")
async def load_model():
    global model, tokenizer
    try:
        print(f"🔄 Loading {MODEL_NAME} on CPU...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # float32 on CPU; low_cpu_mem_usage streams weights in to keep peak RAM down
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        model.eval()
        print("✅ Model loaded on CPU successfully!")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        raise


class Message(BaseModel):
    role: str
    content: str


# Request model also carries per-request generation settings
class ChatRequest(BaseModel):
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 256
    system_prompt: Optional[str] = "You are DNAI, a helpful and humorous AI assistant."


def format_chat_prompt(messages: List[Message], system_prompt: str) -> str:
    # Prepend the system prompt, then replay the conversation turn by turn,
    # ending with an open "Assistant:" tag for the model to complete
    formatted = f"System: {system_prompt}\n"
    for msg in messages:
        if msg.role == "user":
            formatted += f"User: {msg.content}\n"
        elif msg.role == "assistant":
            formatted += f"Assistant: {msg.content}\n"
    formatted += "Assistant:"
    return formatted


@app.get("/", response_class=HTMLResponse)
async def root():
    html_path = Path(__file__).parent / "index.html"
    if html_path.exists():
        with open(html_path, "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read(), status_code=200)
    return HTMLResponse(content="<h1>Error: index.html not found</h1>", status_code=404)


@app.post("/api/chat")
async def chat(request: ChatRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    try:
        # Pass the system prompt through explicitly
        prompt = format_chat_prompt(request.messages, request.system_prompt)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Slicing the decoded string by
        # len(prompt) is fragile, since detokenization can alter whitespace and
        # shift the offset; slicing at the token level avoids that.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # Cut off any hallucinated next user turn
        if "User:" in response:
            response = response.split("User:")[0].strip()

        return {"response": response}
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)