DarkNeuron-AI commited on
Commit
d3d8567
·
verified ·
1 Parent(s): d971cfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -44
app.py CHANGED
@@ -1,51 +1,53 @@
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import HTMLResponse
4
  from pydantic import BaseModel
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import torch
7
  from pathlib import Path
8
  from typing import List, Optional
9
 
10
- app = FastAPI(title="DNAI Humour Chatbot API", version="1.1")
11
 
 
12
  app.add_middleware(
13
  CORSMiddleware,
14
- allow_origins=["*"],
15
  allow_credentials=True,
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
19
 
20
- # Global variables
21
  model = None
22
  tokenizer = None
23
  MODEL_NAME = "DarkNeuronAI/dnai-humour-0.5B-instruct"
24
 
 
25
  @app.on_event("startup")
26
  async def load_model():
27
  global model, tokenizer
28
- try:
29
- print(f"๐Ÿ”„ Loading {MODEL_NAME} on CPU...")
30
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
31
- # Low CPU memory usage logic
32
- model = AutoModelForCausalLM.from_pretrained(
33
- MODEL_NAME,
34
- torch_dtype=torch.float32,
35
- device_map="cpu",
36
- low_cpu_mem_usage=True
37
- )
38
- model.eval()
39
- print("โœ… Model loaded on CPU successfully!")
40
- except Exception as e:
41
- print(f"โŒ Error loading model: {str(e)}")
42
- raise
43
 
 
44
  class Message(BaseModel):
45
  role: str
46
  content: str
47
 
48
- # Updated Request Model to accept Settings
49
  class ChatRequest(BaseModel):
50
  messages: List[Message]
51
  temperature: Optional[float] = 0.7
@@ -53,36 +55,59 @@ class ChatRequest(BaseModel):
53
  max_tokens: Optional[int] = 256
54
  system_prompt: Optional[str] = "You are DNAI, a helpful and humorous AI assistant."
55
 
56
- def format_chat_prompt(messages: List[Message], system_prompt: str) -> str:
57
- # Adding System Prompt to the beginning
58
- formatted = f"System: {system_prompt}\n"
 
 
 
59
  for msg in messages:
60
  if msg.role == "user":
61
- formatted += f"User: {msg.content}\n"
62
  elif msg.role == "assistant":
63
- formatted += f"Assistant: {msg.content}\n"
64
- formatted += "Assistant:"
65
- return formatted
 
 
 
 
 
 
 
 
66
 
 
67
  @app.get("/", response_class=HTMLResponse)
68
  async def root():
69
  html_path = Path(__file__).parent / "index.html"
 
70
  if html_path.exists():
71
- with open(html_path, 'r', encoding='utf-8') as f:
72
- return HTMLResponse(content=f.read(), status_code=200)
73
- return "<h1>Error: index.html not found</h1>"
74
 
 
 
75
  @app.post("/api/chat")
76
  async def chat(request: ChatRequest):
 
77
  if model is None:
78
- raise HTTPException(status_code=503, detail="Model loading")
79
-
80
  try:
81
- # Pass system prompt explicitly
82
- prompt = format_chat_prompt(request.messages, request.system_prompt)
83
-
84
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
85
-
 
 
 
 
 
 
 
86
  with torch.no_grad():
87
  outputs = model.generate(
88
  **inputs,
@@ -90,21 +115,31 @@ async def chat(request: ChatRequest):
90
  temperature=request.temperature,
91
  top_p=request.top_p,
92
  do_sample=True,
 
93
  pad_token_id=tokenizer.eos_token_id
94
  )
95
-
96
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
97
- # Robust extraction
 
 
 
98
  response = generated_text[len(prompt):].strip()
 
99
  if "User:" in response:
100
  response = response.split("User:")[0].strip()
101
-
102
- return {"response": response}
103
-
 
 
 
104
  except Exception as e:
105
- print(f"Error: {e}")
106
  raise HTTPException(status_code=500, detail=str(e))
107
 
 
 
108
  if __name__ == "__main__":
109
  import uvicorn
110
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, HTTPException
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import HTMLResponse, JSONResponse
4
  from pydantic import BaseModel
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import torch
7
  from pathlib import Path
8
  from typing import List, Optional
9
 
10
# FastAPI application instance for the DNAI humour chatbot service.
app = FastAPI(title="DNAI Humour Chatbot API", version="1.2")

# ---------------- CORS ----------------
# Wide-open CORS for development; tighten allow_origins to the real
# frontend domain before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
20
 
21
# ---------------- MODEL ----------------
# Module-level model state: both stay None until the startup hook below
# finishes loading; /api/chat returns 503 while model is None.
model = None
tokenizer = None
# Hugging Face Hub repo id loaded at startup.
MODEL_NAME = "DarkNeuronAI/dnai-humour-0.5B-instruct"
25
 
26
+
27
# NOTE(review): @app.on_event is deprecated in newer FastAPI — consider
# migrating to a lifespan handler; behavior is identical here.
@app.on_event("startup")
async def load_model():
    """Load the tokenizer and model into the module-level globals at startup.

    Until this completes, ``model`` remains None and the chat endpoint
    responds with 503. Any exception raised here aborts app startup.
    """
    global model, tokenizer
    print("๐Ÿ”„ Loading model...")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # CPU-only inference in float32; low_cpu_mem_usage streams weights in
    # to keep peak RAM down on small hosts (e.g. a free Space).
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True
    )

    # Inference mode: disables dropout and other training-only behavior.
    model.eval()
    print("โœ… Model Ready")
 
44
 
45
# ---------------- REQUEST MODELS ----------------
class Message(BaseModel):
    """One chat turn; format_chat_prompt only recognizes roles
    'user' and 'assistant' — other roles are silently dropped."""
    # Speaker tag: 'user' or 'assistant'.
    role: str
    # Raw message text for this turn.
    content: str
49
 
50
+
51
class ChatRequest(BaseModel):
    """Payload for /api/chat: conversation history plus sampling settings."""
    messages: List[Message]
    temperature: Optional[float] = 0.7
    # FIX: the /api/chat handler reads request.top_p — the field must be
    # declared here (nucleus-sampling cutoff, matching the previous default).
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 256
    system_prompt: Optional[str] = "You are DNAI, a helpful and humorous AI assistant."
57
 
58
+
59
# ---------------- PROMPT FORMAT ----------------
def format_chat_prompt(messages: List[Message], system_prompt: str):
    """Flatten a chat history into the plain-text System/User/Assistant
    transcript the model was tuned on; roles other than user/assistant
    are skipped."""
    role_tags = {"user": "User", "assistant": "Assistant"}

    lines = [f"System: {system_prompt}"]
    for msg in messages:
        tag = role_tags.get(msg.role)
        if tag is not None:
            lines.append(f"{tag}: {msg.content}")

    # Trailing bare "Assistant:" cues the model to write the next reply.
    return "\n".join(lines) + "\n" + "Assistant:"
72
+
73
+
74
# ---------------- HEALTH CHECK ----------------
@app.get("/health")
async def health():
    """Liveness probe; also reports whether the model finished loading."""
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}
78
+
79
 
80
# ---------------- SERVE WEBSITE ----------------
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve index.html from the app directory, or a minimal error page."""
    html_path = Path(__file__).parent / "index.html"

    # Guard clause: missing file falls through to the error page (the
    # response_class still renders it as HTML).
    if not html_path.exists():
        return "<h1>index.html not found</h1>"

    return HTMLResponse(html_path.read_text(encoding="utf-8"))
89
 
90
+
91
# ---------------- CHAT API ----------------
@app.post("/api/chat")
async def chat(request: ChatRequest):
    """Generate a reply for the supplied chat history.

    Returns {"response": str, "status": "success"}.
    Raises 503 while the model is still loading, 500 on generation failure.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model still loading")

    try:
        prompt = format_chat_prompt(
            request.messages,
            request.system_prompt
        )

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024
        )

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id
            )

        # FIX: strip the prompt at the *token* level rather than
        # generated_text[len(prompt):]. decode() does not always round-trip
        # the prompt text exactly (special tokens, whitespace normalization),
        # so character slicing could clip the reply or leak prompt fragments.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(
            outputs[0][prompt_len:],
            skip_special_tokens=True
        ).strip()

        # Drop any hallucinated continuation of the user's next turn.
        if "User:" in response:
            response = response.split("User:")[0].strip()

        return JSONResponse({
            "response": response,
            "status": "success"
        })

    except Exception as e:
        # Log server-side; surface the message to the client as a 500.
        print("โŒ Generation error:", e)
        raise HTTPException(status_code=500, detail=str(e))
140
 
141
+
142
+ # ---------------- LOCAL RUN ----------------
143
  if __name__ == "__main__":
144
  import uvicorn
145
  uvicorn.run(app, host="0.0.0.0", port=7860)