Spaces:
Running on Zero
Running on Zero
Siddharth Ravikumar commited on
Commit Β·
1f69fb6
1
Parent(s): f9d2cd6
fix: make Chat Agent robust with detailed logging and GPU-context loading
Browse files- app.py +19 -8
- backend/app/core/inference.py +49 -33
app.py
CHANGED
|
@@ -60,11 +60,15 @@ inference_engine._run_inference = gpu_run_inference
|
|
| 60 |
_original_chat = chat_engine.chat
|
| 61 |
|
| 62 |
@spaces.GPU(duration=60)
|
| 63 |
-
def gpu_run_chat(system_context
|
| 64 |
-
"""GPU-accelerated chat inference"""
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
# ββ Async helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -542,14 +546,21 @@ SCENE ANALYSES:\n"""
|
|
| 542 |
def chat_respond(user_message, history, system_ctx):
|
| 543 |
if not user_message or not user_message.strip():
|
| 544 |
return history, "", system_ctx
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
|
|
|
|
|
|
| 548 |
try:
|
|
|
|
| 549 |
response = gpu_run_chat(system_ctx, user_message.strip())
|
|
|
|
| 550 |
except Exception as e:
|
|
|
|
| 551 |
response = f"Error: {e}"
|
|
|
|
| 552 |
history = history or []
|
|
|
|
| 553 |
history.append({"role": "user", "content": user_message.strip()})
|
| 554 |
history.append({"role": "assistant", "content": response})
|
| 555 |
return history, "", system_ctx
|
|
|
|
| 60 |
_original_chat = chat_engine.chat
|
| 61 |
|
| 62 |
@spaces.GPU(duration=60)
|
| 63 |
+
def gpu_run_chat(system_context, user_message):
|
| 64 |
+
"""GPU-accelerated chat inference."""
|
| 65 |
+
try:
|
| 66 |
+
# We call the engine's original method directly to avoid monkey-patch recursion
|
| 67 |
+
# And let the engine handle its own loading inside this GPU worker
|
| 68 |
+
return _original_chat(system_context, user_message)
|
| 69 |
+
except Exception as e:
|
| 70 |
+
logger.error(f"ZeroGPU Chat Worker Error: {e}")
|
| 71 |
+
return f"Worker Error: {e}"
|
| 72 |
|
| 73 |
|
| 74 |
# ββ Async helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 546 |
def chat_respond(user_message, history, system_ctx):
|
| 547 |
if not user_message or not user_message.strip():
|
| 548 |
return history, "", system_ctx
|
| 549 |
+
|
| 550 |
+
# ensure_init connects DB and loads rules, but not the models
|
| 551 |
+
run_async(_ensure_init())
|
| 552 |
+
|
| 553 |
+
logger.info(f"Chat request: {user_message[:50]}...")
|
| 554 |
try:
|
| 555 |
+
# Call the @spaces.GPU decorated function directly
|
| 556 |
response = gpu_run_chat(system_ctx, user_message.strip())
|
| 557 |
+
logger.info(f"Received response: {response[:50]}...")
|
| 558 |
except Exception as e:
|
| 559 |
+
logger.error(f"Chat failed: {e}")
|
| 560 |
response = f"Error: {e}"
|
| 561 |
+
|
| 562 |
history = history or []
|
| 563 |
+
# Use Gradio 5.0 message format (dict)
|
| 564 |
history.append({"role": "user", "content": user_message.strip()})
|
| 565 |
history.append({"role": "assistant", "content": response})
|
| 566 |
return history, "", system_ctx
|
backend/app/core/inference.py
CHANGED
|
@@ -253,36 +253,45 @@ class ChatEngine:
|
|
| 253 |
|
| 254 |
def load_model(self):
|
| 255 |
"""Load the text-only chat model."""
|
|
|
|
|
|
|
|
|
|
| 256 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 257 |
|
| 258 |
model_id = settings.chat_model_id
|
| 259 |
-
logger.info(f"Loading chat model: {model_id}")
|
| 260 |
-
|
| 261 |
device = settings.resolve_device()
|
| 262 |
dtype = settings.resolve_torch_dtype()
|
|
|
|
|
|
|
| 263 |
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
-
|
| 272 |
-
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
def chat(self, system_context: str, user_message: str) -> str:
|
| 279 |
"""
|
| 280 |
Generate a response given system context and a user question.
|
| 281 |
-
system_context: case data, traffic rules, etc.
|
| 282 |
-
user_message: the user's question
|
| 283 |
"""
|
| 284 |
if not self.is_loaded:
|
| 285 |
-
|
| 286 |
|
| 287 |
messages = [
|
| 288 |
{"role": "system", "content": system_context},
|
|
@@ -293,25 +302,32 @@ class ChatEngine:
|
|
| 293 |
text_prompt = self._tokenizer.apply_chat_template(
|
| 294 |
messages, add_generation_prompt=True, tokenize=False,
|
| 295 |
)
|
| 296 |
-
|
| 297 |
-
|
|
|
|
| 298 |
text_prompt = f"System: {system_context}\n\nUser: {user_message}\n\nAssistant:"
|
| 299 |
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
|
| 317 |
# Singleton instance
|
|
|
|
| 253 |
|
| 254 |
def load_model(self):
|
| 255 |
"""Load the text-only chat model."""
|
| 256 |
+
if self.is_loaded:
|
| 257 |
+
return
|
| 258 |
+
|
| 259 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 260 |
|
| 261 |
model_id = settings.chat_model_id
|
|
|
|
|
|
|
| 262 |
device = settings.resolve_device()
|
| 263 |
dtype = settings.resolve_torch_dtype()
|
| 264 |
+
|
| 265 |
+
logger.info(f"DEBUG: ChatEngine loading model {model_id} on {device}...")
|
| 266 |
|
| 267 |
+
try:
|
| 268 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
| 269 |
+
model_id, trust_remote_code=settings.model_trust_remote_code,
|
| 270 |
+
)
|
| 271 |
+
# Use float16 for GPU, float32 for CPU to avoid issues
|
| 272 |
+
self._model = AutoModelForCausalLM.from_pretrained(
|
| 273 |
+
model_id,
|
| 274 |
+
torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32,
|
| 275 |
+
trust_remote_code=settings.model_trust_remote_code,
|
| 276 |
+
low_cpu_mem_usage=True
|
| 277 |
+
)
|
| 278 |
|
| 279 |
+
if device != "cpu":
|
| 280 |
+
self._model = self._model.to(device)
|
| 281 |
|
| 282 |
+
self._device = device
|
| 283 |
+
self.is_loaded = True
|
| 284 |
+
logger.info(f"DEBUG: Chat model loaded successfully on {device}")
|
| 285 |
+
except Exception as e:
|
| 286 |
+
logger.error(f"DEBUG ERROR: Chat model load failed: {str(e)}")
|
| 287 |
+
raise e
|
| 288 |
|
| 289 |
def chat(self, system_context: str, user_message: str) -> str:
|
| 290 |
"""
|
| 291 |
Generate a response given system context and a user question.
|
|
|
|
|
|
|
| 292 |
"""
|
| 293 |
if not self.is_loaded:
|
| 294 |
+
self.load_model()
|
| 295 |
|
| 296 |
messages = [
|
| 297 |
{"role": "system", "content": system_context},
|
|
|
|
| 302 |
text_prompt = self._tokenizer.apply_chat_template(
|
| 303 |
messages, add_generation_prompt=True, tokenize=False,
|
| 304 |
)
|
| 305 |
+
logger.info(f"DEBUG: Chat prompt prepared (length: {len(text_prompt)})")
|
| 306 |
+
except Exception as e:
|
| 307 |
+
logger.warning(f"DEBUG: Chat template failed ({e}), using fallback")
|
| 308 |
text_prompt = f"System: {system_context}\n\nUser: {user_message}\n\nAssistant:"
|
| 309 |
|
| 310 |
+
try:
|
| 311 |
+
inputs = self._tokenizer(text_prompt, return_tensors="pt").to(self._device)
|
| 312 |
+
logger.info(f"DEBUG: Inputs tokenized (length: {inputs['input_ids'].shape[1]})")
|
| 313 |
+
|
| 314 |
+
with torch.inference_mode():
|
| 315 |
+
outputs = self._model.generate(
|
| 316 |
+
**inputs,
|
| 317 |
+
max_new_tokens=512,
|
| 318 |
+
repetition_penalty=1.2,
|
| 319 |
+
temperature=0.4,
|
| 320 |
+
do_sample=True,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
prompt_length = inputs["input_ids"].shape[1]
|
| 324 |
+
generated_tokens = outputs[0][prompt_length:]
|
| 325 |
+
response = self._tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 326 |
+
logger.info(f"DEBUG: Response generated successfully (length: {len(response)})")
|
| 327 |
+
return response.strip()
|
| 328 |
+
except Exception as e:
|
| 329 |
+
logger.error(f"DEBUG ERROR: Inference failed: {str(e)}")
|
| 330 |
+
raise e
|
| 331 |
|
| 332 |
|
| 333 |
# Singleton instance
|