Siddharth Ravikumar committed
Commit 1f69fb6 · 1 Parent(s): f9d2cd6

fix: make Chat Agent robust with detailed logging and GPU-context loading

Files changed (2)
  1. app.py  +19 -8
  2. backend/app/core/inference.py  +49 -33
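Why the fix takes this shape: on a ZeroGPU Space, CUDA is only available inside a function decorated with @spaces.GPU, so the chat model has to be loaded lazily inside that GPU worker rather than at import time, and the Gradio handler calls the decorated wrapper directly instead of relying on a monkey-patched method. A minimal sketch of the pattern, with a hypothetical Engine class standing in for ChatEngine (the model id, settings, and generation parameters below are illustrative, not the app's actual values):

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class Engine:
    """Hypothetical stand-in for ChatEngine: lazy load, then generate."""
    def __init__(self, model_id):
        self.model_id = model_id
        self.is_loaded = False

    def load_model(self):
        if self.is_loaded:          # idempotent: cheap to call on every request
            return
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)
        self.device, self.is_loaded = device, True

    def chat(self, system_context, user_message):
        self.load_model()           # runs inside the GPU worker, where CUDA exists
        prompt = f"System: {system_context}\n\nUser: {user_message}\n\nAssistant:"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.inference_mode():
            out = self.model.generate(**inputs, max_new_tokens=128)
        return self.tokenizer.decode(
            out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        ).strip()

engine = Engine("Qwen/Qwen2.5-0.5B-Instruct")    # illustrative model id
_original_chat = engine.chat                     # keep a direct reference

@spaces.GPU(duration=60)                         # GPU is attached only for this call
def gpu_run_chat(system_context, user_message):
    return _original_chat(system_context, user_message)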
app.py CHANGED
@@ -60,11 +60,15 @@ inference_engine._run_inference = gpu_run_inference
 _original_chat = chat_engine.chat
 
 @spaces.GPU(duration=60)
-def gpu_run_chat(system_context: str, user_message: str):
-    """GPU-accelerated chat inference"""
-    return _original_chat(system_context, user_message)
-
-chat_engine.chat = gpu_run_chat
+def gpu_run_chat(system_context, user_message):
+    """GPU-accelerated chat inference."""
+    try:
+        # We call the engine's original method directly to avoid monkey-patch recursion
+        # and let the engine handle its own loading inside this GPU worker
+        return _original_chat(system_context, user_message)
+    except Exception as e:
+        logger.error(f"ZeroGPU Chat Worker Error: {e}")
+        return f"Worker Error: {e}"
 
 
 # ── Async helpers ──────────────────────────────────────────────────────
@@ -542,14 +546,21 @@ SCENE ANALYSES:\n"""
 def chat_respond(user_message, history, system_ctx):
     if not user_message or not user_message.strip():
         return history, "", system_ctx
-    ensure_init()
-    if not chat_engine.is_loaded:
-        chat_engine.load_model()
+
+    # ensure_init connects DB and loads rules, but not the models
+    run_async(_ensure_init())
+
+    logger.info(f"Chat request: {user_message[:50]}...")
     try:
+        # Call the @spaces.GPU decorated function directly
         response = gpu_run_chat(system_ctx, user_message.strip())
+        logger.info(f"Received response: {response[:50]}...")
     except Exception as e:
+        logger.error(f"Chat failed: {e}")
         response = f"Error: {e}"
+
     history = history or []
+    # Use Gradio 5.0 message format (dict)
     history.append({"role": "user", "content": user_message.strip()})
     history.append({"role": "assistant", "content": response})
     return history, "", system_ctx
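For context, the history entries appended above follow Gradio's "messages" format (a list of {"role": ..., "content": ...} dicts), which is what a Chatbot component configured with type="messages" expects. A minimal wiring sketch, assuming the rest of app.py hooks chat_respond up roughly like this (the component names and labels here are illustrative, not taken from the file):

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")   # expects [{"role": ..., "content": ...}]
    msg = gr.Textbox(label="Ask about the case")
    system_ctx = gr.State("")               # system context built elsewhere in app.py

    # chat_respond returns (updated history, cleared textbox, unchanged context)
    msg.submit(chat_respond, [msg, chatbot, system_ctx], [chatbot, msg, system_ctx])

demo.launch()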
backend/app/core/inference.py CHANGED
@@ -253,36 +253,45 @@ class ChatEngine:
 
     def load_model(self):
         """Load the text-only chat model."""
+        if self.is_loaded:
+            return
+
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
         model_id = settings.chat_model_id
-        logger.info(f"Loading chat model: {model_id}")
-
         device = settings.resolve_device()
         dtype = settings.resolve_torch_dtype()
+
+        logger.info(f"DEBUG: ChatEngine loading model {model_id} on {device}...")
 
-        self._tokenizer = AutoTokenizer.from_pretrained(
-            model_id, trust_remote_code=settings.model_trust_remote_code,
-        )
-        self._model = AutoModelForCausalLM.from_pretrained(
-            model_id, torch_dtype=dtype, trust_remote_code=settings.model_trust_remote_code,
-        )
+        try:
+            self._tokenizer = AutoTokenizer.from_pretrained(
+                model_id, trust_remote_code=settings.model_trust_remote_code,
+            )
+            # Use float16 for GPU, float32 for CPU to avoid issues
+            self._model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=torch.float16 if "cuda" in str(device) else torch.float32,
+                trust_remote_code=settings.model_trust_remote_code,
+                low_cpu_mem_usage=True
+            )
 
-        if device != "cpu":
-            self._model = self._model.to(device)
+            if device != "cpu":
+                self._model = self._model.to(device)
 
-        self._device = device
-        self.is_loaded = True
-        logger.info(f"Chat model loaded on {device}")
+            self._device = device
+            self.is_loaded = True
+            logger.info(f"DEBUG: Chat model loaded successfully on {device}")
+        except Exception as e:
+            logger.error(f"DEBUG ERROR: Chat model load failed: {str(e)}")
+            raise e
 
     def chat(self, system_context: str, user_message: str) -> str:
         """
         Generate a response given system context and a user question.
-        - system_context: case data, traffic rules, etc.
-        - user_message: the user's question
         """
         if not self.is_loaded:
-            raise RuntimeError("Chat model not loaded. Call load_model() first.")
+            self.load_model()
 
         messages = [
             {"role": "system", "content": system_context},
@@ -293,25 +302,32 @@ class ChatEngine:
             text_prompt = self._tokenizer.apply_chat_template(
                 messages, add_generation_prompt=True, tokenize=False,
             )
-        except Exception:
-            # Fallback if no chat template
+            logger.info(f"DEBUG: Chat prompt prepared (length: {len(text_prompt)})")
+        except Exception as e:
+            logger.warning(f"DEBUG: Chat template failed ({e}), using fallback")
             text_prompt = f"System: {system_context}\n\nUser: {user_message}\n\nAssistant:"
 
-        inputs = self._tokenizer(text_prompt, return_tensors="pt").to(self._device)
-
-        with torch.inference_mode():
-            outputs = self._model.generate(
-                **inputs,
-                max_new_tokens=512,
-                repetition_penalty=1.2,
-                temperature=0.4,
-                do_sample=True,
-            )
-
-        prompt_length = inputs["input_ids"].shape[1]
-        generated_tokens = outputs[0][prompt_length:]
-        response = self._tokenizer.decode(generated_tokens, skip_special_tokens=True)
-        return response.strip()
+        try:
+            inputs = self._tokenizer(text_prompt, return_tensors="pt").to(self._device)
+            logger.info(f"DEBUG: Inputs tokenized (length: {inputs['input_ids'].shape[1]})")
+
+            with torch.inference_mode():
+                outputs = self._model.generate(
+                    **inputs,
+                    max_new_tokens=512,
+                    repetition_penalty=1.2,
+                    temperature=0.4,
+                    do_sample=True,
+                )
+
+            prompt_length = inputs["input_ids"].shape[1]
+            generated_tokens = outputs[0][prompt_length:]
+            response = self._tokenizer.decode(generated_tokens, skip_special_tokens=True)
+            logger.info(f"DEBUG: Response generated successfully (length: {len(response)})")
+            return response.strip()
+        except Exception as e:
+            logger.error(f"DEBUG ERROR: Inference failed: {str(e)}")
+            raise e
 
 
 # Singleton instance
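With load_model() now idempotent and chat() loading on demand, callers no longer need an explicit load step; the first request simply pays the load cost inside whichever process (or GPU worker) handles it. A hypothetical usage sketch (the import path and singleton name are assumptions based on the file layout above; the prompt text is illustrative):

from backend.app.core.inference import chat_engine  # assumed singleton name

# No explicit load_model() call: chat() lazy-loads on first use and
# load_model() returns immediately if the model is already resident.
answer = chat_engine.chat(
    system_context="You are a traffic-incident assistant. Case data: ...",
    user_message="Who had right of way at the intersection?",
)
print(answer)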