File size: 3,559 Bytes
a4c8c42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcbbfd9
a4c8c42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pathlib import Path
from typing import List, Optional

app = FastAPI(title="DNAI Humour Chatbot API", version="1.1")

# Open CORS so browser front-ends on any origin can call the API.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive — restrict origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables
# Both stay None until the startup handler below finishes loading; the /api/chat
# endpoint returns 503 while model is still None.
model = None
tokenizer = None
# Hugging Face Hub identifier of the chat model to load at startup.
MODEL_NAME = "DarkNeuronAI/dnai-humour-0.5B-instruct"

@app.on_event("startup")
async def load_model():
    """Load the tokenizer and CPU model once at startup, filling the globals.

    Re-raises any loading error so the server fails fast instead of serving
    a permanently broken /api/chat endpoint.
    """
    global model, tokenizer
    try:
        print(f"๐Ÿ”„ Loading {MODEL_NAME} on CPU...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Keep everything in float32 on CPU and stream weights to limit RAM use.
        loaded = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        loaded.eval()  # inference mode: disables dropout etc.
        model = loaded
        print("โœ… Model loaded on CPU successfully!")
    except Exception as e:
        print(f"โŒ Error loading model: {str(e)}")
        raise

class Message(BaseModel):
    """A single chat turn sent by the client."""
    # Expected "user" or "assistant"; other roles are skipped when the prompt
    # is formatted (see format_chat_prompt).
    role: str
    # Plain-text content of the turn.
    content: str

# Updated Request Model to accept Settings
class ChatRequest(BaseModel):
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 256
    system_prompt: Optional[str] = "You are DNAI, a helpful and humorous AI assistant."

def format_chat_prompt(messages: List[Message], system_prompt: str) -> str:
    """Render the conversation as a plain-text prompt.

    Produces a "System:" line, one "User:"/"Assistant:" line per recognised
    turn (any other role is silently skipped), and a trailing "Assistant:"
    cue for the model to continue from.
    """
    role_labels = {"user": "User", "assistant": "Assistant"}
    lines = [f"System: {system_prompt}"]
    for turn in messages:
        label = role_labels.get(turn.role)
        if label is not None:
            lines.append(f"{label}: {turn.content}")
    lines.append("Assistant:")
    return "\n".join(lines)

@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the bundled index.html, or a placeholder error page if absent."""
    page = Path(__file__).parent / "index.html"
    if not page.exists():
        return "<h1>Error: index.html not found</h1>"
    return HTMLResponse(content=page.read_text(encoding="utf-8"), status_code=200)

@app.post("/api/chat")
async def chat(request: ChatRequest):
    """Generate an assistant reply for the conversation in `request`.

    Returns {"response": <text>}. Responds 503 while the model is still
    loading and 500 on any generation failure.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model loading")

    try:
        # Build the text prompt, honouring the client-supplied system prompt.
        prompt = format_chat_prompt(request.messages, request.system_prompt)

        # Truncate long histories so the encoded prompt stays within 1024 tokens.
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)

        with torch.no_grad():
            output_ids = model.generate(
                **encoded,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        full_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Drop the echoed prompt, then cut off any hallucinated next user turn.
        reply = full_text[len(prompt):].strip()
        if "User:" in reply:
            reply = reply.split("User:")[0].strip()

        return {"response": reply}

    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    # Dev entry point: run the ASGI app directly (python <this file>) on port 7860.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)