Update app.py

app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import torch
+import time
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
 from peft import PeftModel
 from fastapi import FastAPI, HTTPException
@@ -7,6 +8,7 @@ from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from typing import List
 import gc
+import spaces  # ✅ Import for Zero GPU
 
 # -----------------------------
 # CONFIG
@@ -71,8 +73,8 @@ try:
         token=HF_TOKEN,
         low_cpu_mem_usage=True,
         trust_remote_code=True,
-        offload_folder=offload_dir,
-        offload_state_dict=True,
+        offload_folder=offload_dir,
+        offload_state_dict=True,
     )
 
     print("✅ Base model loaded in 4-bit!")
@@ -83,7 +85,7 @@ try:
         base_model,
         LORA_REPO,
         token=HF_TOKEN,
-        offload_folder=offload_dir,
+        offload_folder=offload_dir,
     )
 
     print("✅ LoRA adapter loaded!")
@@ -170,11 +172,13 @@ def detect_mood(text: str) -> str:
         print(f"Mood detection error: {e}")
         return "neutral"
 
+# ✅ ZERO GPU DECORATOR - This gets you FREE GPU!
+@spaces.GPU(duration=60)  # Max 60 seconds per request
 def generate_shinchan_response(user_input: str, mood: str) -> str:
     """Generate Shinchan's response based on user input and mood"""
     try:
-        #
-        context = "\n".join(memory[-
+        # Use last 2 exchanges (4 entries) for context
+        context = "\n".join(memory[-4:]) if memory else ""
 
         prompt = (
             f"<s>[INST] {SYS_PROMPT}\n"
@@ -187,7 +191,7 @@ def generate_shinchan_response(user_input: str, mood: str) -> str:
             prompt,
             return_tensors="pt",
             truncation=True,
-            max_length=
+            max_length=384,
             padding=True
         )
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
@@ -197,12 +201,13 @@ def generate_shinchan_response(user_input: str, mood: str) -> str:
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=
+                max_new_tokens=50,
                 temperature=temperature,
                 top_p=0.9,
                 top_k=50,
                 repetition_penalty=1.15,
                 do_sample=True,
+                num_beams=1,
                 pad_token_id=tokenizer.eos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
             )
@@ -227,6 +232,8 @@ def generate_shinchan_response(user_input: str, mood: str) -> str:
 
     except Exception as e:
         print(f"Generation error: {e}")
+        import traceback
+        traceback.print_exc()
         return f"Arrey yaar! Something went wrong 😤 Error: {str(e)[:100]}"
 
 # -----------------------------
@@ -250,34 +257,55 @@ async def health():
         "lora": LORA_REPO,
         "device": str(next(model.parameters()).device),
         "memory_entries": len(memory),
-        "mood_detection": mood_pipe is not None
+        "mood_detection": mood_pipe is not None,
+        "gpu_available": torch.cuda.is_available()
     }
 
 @app.post("/chat", response_model=MessageResponse)
 async def chat_endpoint(req: MessageRequest):
-    """Main chat endpoint"""
+    """Main chat endpoint with Zero GPU acceleration"""
+    start_time = time.time()
+
+    print(f"\n{'='*60}")
+    print(f"🔵 CHAT REQUEST at {time.strftime('%H:%M:%S')}")
+    print(f"   Input: '{req.user_input[:50]}{'...' if len(req.user_input) > 50 else ''}'")
+    print(f"{'='*60}")
+
     try:
         if not req.user_input or not req.user_input.strip():
             raise HTTPException(status_code=400, detail="Empty message")
 
        user_text = req.user_input.strip()
 
-        # Detect mood
+        # Step 1: Detect mood (runs on CPU)
+        print("⏱️ [1/2] Detecting mood...")
+        mood_start = time.time()
         mood = detect_mood(user_text)
+        mood_time = time.time() - mood_start
+        print(f"✅ Mood: {mood} ({mood_time:.2f}s)")
 
-        # Generate response
+        # Step 2: Generate response (runs on GPU with @spaces.GPU)
+        print("⏱️ [2/2] Generating response (GPU)...")
+        gen_start = time.time()
         response = generate_shinchan_response(user_text, mood)
+        gen_time = time.time() - gen_start
+        print(f"✅ Generated ({gen_time:.2f}s)")
+
+        total_time = time.time() - start_time
+        print(f"🏁 TOTAL: {total_time:.2f}s")
+        print(f"{'='*60}\n")
 
         return MessageResponse(
             response=response,
             mood=mood,
-            memory=memory[-10:]
+            memory=memory[-10:]
         )
 
     except HTTPException:
         raise
     except Exception as e:
-
+        elapsed = time.time() - start_time
+        print(f"❌ ERROR after {elapsed:.2f}s: {e}")
         import traceback
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")
@@ -286,7 +314,9 @@ async def chat_endpoint(req: MessageRequest):
 async def reset_memory():
     """Reset conversation memory"""
     global memory
+    old_size = len(memory)
     memory = []
+    print(f"🔄 Memory reset (cleared {old_size} entries)")
     return {"status": "Memory cleared", "memory_size": 0}
 
 # -----------------------------
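
The core of this commit is the ZeroGPU pattern: on a ZeroGPU Space no GPU is held permanently; one is attached only while a function decorated with @spaces.GPU executes, which is why generate_shinchan_response gets the decorator while detect_mood stays on CPU. A minimal sketch of that pattern, using a placeholder model id rather than this Space's actual repos:

import spaces  # must be imported before any CUDA initialization
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "example-org/example-model"  # placeholder, not this Space's model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)  # loaded at startup, on CPU
model.to("cuda")  # allowed at startup on ZeroGPU; a real GPU is only attached inside @spaces.GPU calls

@spaces.GPU(duration=60)  # request a GPU slot for at most ~60 s per call
def generate(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=50, do_sample=True)
    return tokenizer.decode(out[0], skip_special_tokens=True)

Once deployed, the endpoint can be exercised with a plain HTTP call; the URL below is a placeholder, but the request and response shapes match MessageRequest and MessageResponse from the diff:

import requests

r = requests.post("https://your-space.hf.space/chat", json={"user_input": "Hi Shinchan!"})
print(r.json())  # {"response": ..., "mood": ..., "memory": [...]}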
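
The offload_folder and offload_state_dict arguments that this diff re-adds belong to the model-loading block near the top of app.py. For context, a sketch of how a 4-bit base model plus a PEFT LoRA adapter are typically loaded with those options; BASE_REPO, the BitsAndBytesConfig settings, and offload_dir are illustrative stand-ins, not values taken from the file:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

BASE_REPO = "example-org/base-model"    # placeholder for the app's base model
LORA_REPO = "example-org/lora-adapter"  # placeholder for the app's adapter
HF_TOKEN = None                         # the real app reads this from its environment
offload_dir = "offload"                 # scratch folder for weights that do not fit in memory

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize linear layers to 4-bit on load
    bnb_4bit_quant_type="nf4",             # illustrative choice
    bnb_4bit_compute_dtype=torch.float16,  # compute in fp16
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_REPO,
    quantization_config=bnb_config,
    device_map="auto",               # let accelerate place layers across devices
    token=HF_TOKEN,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    offload_folder=offload_dir,      # spill layers that do not fit to disk
    offload_state_dict=True,         # stage the state dict on disk while loading
)

model = PeftModel.from_pretrained(
    base_model,
    LORA_REPO,
    token=HF_TOKEN,
    offload_folder=offload_dir,  # adapter load reuses the same offload target
)

If accelerate has offloaded base-model layers to disk, the adapter load needs an offload target as well, which is consistent with the diff restoring offload_folder on both calls.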