tadaGoel committed on
Commit 51da688 · verified · 1 Parent(s): 43774a0

Update app.py

Files changed (1)
  1. app.py +43 -13
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
 import torch
+import time
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
 from peft import PeftModel
 from fastapi import FastAPI, HTTPException
@@ -7,6 +8,7 @@ from pydantic import BaseModel
 from fastapi.middleware.cors import CORSMiddleware
 from typing import List
 import gc
+import spaces  # ✅ Import for Zero GPU

 # -----------------------------
 # CONFIG
@@ -71,8 +73,8 @@ try:
         token=HF_TOKEN,
         low_cpu_mem_usage=True,
         trust_remote_code=True,
-        offload_folder=offload_dir,  # Enable disk offloading
-        offload_state_dict=True,  # Offload state dict to disk
+        offload_folder=offload_dir,
+        offload_state_dict=True,
     )

     print("✅ Base model loaded in 4-bit!")
@@ -83,7 +85,7 @@ try:
         base_model,
         LORA_REPO,
         token=HF_TOKEN,
-        offload_folder=offload_dir,  # Use same offload directory
+        offload_folder=offload_dir,
     )

     print("✅ LoRA adapter loaded!")
@@ -170,11 +172,13 @@ def detect_mood(text: str) -> str:
         print(f"Mood detection error: {e}")
         return "neutral"

+# ✅ ZERO GPU DECORATOR - This gets you FREE GPU!
+@spaces.GPU(duration=60)  # Max 60 seconds per request
 def generate_shinchan_response(user_input: str, mood: str) -> str:
     """Generate Shinchan's response based on user input and mood"""
     try:
-        # Add conversation context (last 3 exchanges)
-        context = "\n".join(memory[-6:]) if memory else ""
+        # Use last 2 exchanges (4 entries) for context
+        context = "\n".join(memory[-4:]) if memory else ""

         prompt = (
             f"<s>[INST] {SYS_PROMPT}\n"
@@ -187,7 +191,7 @@ def generate_shinchan_response(user_input: str, mood: str) -> str:
             prompt,
             return_tensors="pt",
             truncation=True,
-            max_length=512,
+            max_length=384,
             padding=True
         )
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
@@ -197,12 +201,13 @@ def generate_shinchan_response(user_input: str, mood: str) -> str:
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
-                max_new_tokens=80,
+                max_new_tokens=50,
                 temperature=temperature,
                 top_p=0.9,
                 top_k=50,
                 repetition_penalty=1.15,
                 do_sample=True,
+                num_beams=1,
                 pad_token_id=tokenizer.eos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
             )
@@ -227,6 +232,8 @@ def generate_shinchan_response(user_input: str, mood: str) -> str:

     except Exception as e:
         print(f"Generation error: {e}")
+        import traceback
+        traceback.print_exc()
         return f"Arrey yaar! Something went wrong 🤕 Error: {str(e)[:100]}"

 # -----------------------------
@@ -250,34 +257,55 @@ async def health():
         "lora": LORA_REPO,
         "device": str(next(model.parameters()).device),
         "memory_entries": len(memory),
-        "mood_detection": mood_pipe is not None
+        "mood_detection": mood_pipe is not None,
+        "gpu_available": torch.cuda.is_available()
     }

 @app.post("/chat", response_model=MessageResponse)
 async def chat_endpoint(req: MessageRequest):
-    """Main chat endpoint"""
+    """Main chat endpoint with Zero GPU acceleration"""
+    start_time = time.time()
+
+    print(f"\n{'='*60}")
+    print(f"🔵 CHAT REQUEST at {time.strftime('%H:%M:%S')}")
+    print(f" Input: '{req.user_input[:50]}{'...' if len(req.user_input) > 50 else ''}'")
+    print(f"{'='*60}")
+
     try:
         if not req.user_input or not req.user_input.strip():
             raise HTTPException(status_code=400, detail="Empty message")

         user_text = req.user_input.strip()

-        # Detect mood
+        # Step 1: Detect mood (runs on CPU)
+        print("⏱️ [1/2] Detecting mood...")
+        mood_start = time.time()
         mood = detect_mood(user_text)
+        mood_time = time.time() - mood_start
+        print(f"✅ Mood: {mood} ({mood_time:.2f}s)")

-        # Generate response
+        # Step 2: Generate response (runs on GPU with @spaces.GPU)
+        print("⏱️ [2/2] Generating response (GPU)...")
+        gen_start = time.time()
         response = generate_shinchan_response(user_text, mood)
+        gen_time = time.time() - gen_start
+        print(f"✅ Generated ({gen_time:.2f}s)")
+
+        total_time = time.time() - start_time
+        print(f"🎉 TOTAL: {total_time:.2f}s")
+        print(f"{'='*60}\n")

         return MessageResponse(
             response=response,
             mood=mood,
-            memory=memory[-10:]  # Return last 5 exchanges
+            memory=memory[-10:]
         )

     except HTTPException:
         raise
     except Exception as e:
-        print(f"Chat endpoint error: {e}")
+        elapsed = time.time() - start_time
+        print(f"❌ ERROR after {elapsed:.2f}s: {e}")
         import traceback
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")
@@ -286,7 +314,9 @@ async def chat_endpoint(req: MessageRequest):
 async def reset_memory():
     """Reset conversation memory"""
     global memory
+    old_size = len(memory)
     memory = []
+    print(f"🔄 Memory reset (cleared {old_size} entries)")
     return {"status": "Memory cleared", "memory_size": 0}

 # -----------------------------
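Note: for reference, the pattern behind the new @spaces.GPU decorator. On ZeroGPU Spaces, a GPU is attached only while a decorated function is executing, so models are loaded at startup and only the decorated call runs on the GPU. A minimal standalone sketch, assuming the spaces package that ZeroGPU Spaces preinstall; the model name and prompt below are illustrative placeholders, not this repo's:

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "gpt2"  # placeholder model, chosen only to keep the sketch small

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
model.to("cuda")  # on ZeroGPU this is done at startup; the device is attached lazily

@spaces.GPU(duration=60)  # GPU slice held for at most ~60s per call, as in the commit
def generate(prompt: str) -> str:
    # Everything inside this function runs with the GPU attached
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=50, do_sample=True)
    return tokenizer.decode(out[0], skip_special_tokens=True)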
 
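Note: a quick smoke test for the updated endpoint. Only the /chat route and the user_input / response / mood / memory field names are confirmed by the diff above; the base URL is a placeholder for the actual Space URL:

import requests

BASE = "https://your-space.hf.space"  # placeholder; substitute the real Space URL

# ZeroGPU cold starts can be slow, so allow a generous timeout
r = requests.post(f"{BASE}/chat", json={"user_input": "Hi Shinchan!"}, timeout=120)
r.raise_for_status()
data = r.json()
print("mood:  ", data["mood"])
print("reply: ", data["response"])
print("memory:", len(data["memory"]), "entries")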