Files changed (1)
  1. handler.py +26 -1
handler.py CHANGED
@@ -351,9 +351,11 @@ def generate_response(message_text, image_input, max_output_tokens=4096, repetit
     try:
         if hasattr(our_chatbot, 'conv_mode') and our_chatbot.conv_mode and LLAVA_AVAILABLE:
             our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
+            print(f"[DEBUG] Reset conversation using conv_mode: {our_chatbot.conv_mode}")
         else:
             # Use default conversation template
             our_chatbot.conversation = our_chatbot.conversation.__class__()
+            print(f"[DEBUG] Reset conversation using default template")
     except Exception as e:
         print(f"[DEBUG] Failed to reset conversation: {e}")
         # Continue with existing conversation
@@ -363,6 +365,10 @@ def generate_response(message_text, image_input, max_output_tokens=4096, repetit
     our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
     prompt = our_chatbot.conversation.get_prompt()
 
+    print(f"[DEBUG] Conversation template: {type(our_chatbot.conversation).__name__}")
+    print(f"[DEBUG] Conversation roles: {our_chatbot.conversation.roles if hasattr(our_chatbot.conversation, 'roles') else 'No roles'}")
+    print(f"[DEBUG] Final prompt length: {len(prompt)} characters")
+
     # Tokenize input
     input_ids = tokenizer_image_token(
         prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
@@ -370,6 +376,8 @@ def generate_response(message_text, image_input, max_output_tokens=4096, repetit
 
     # No stopping criteria - let model generate freely up to max_new_tokens
     print(f"[DEBUG] No stopping criteria - free generation up to {max_output_tokens} tokens")
+    print(f"[DEBUG] Input prompt length: {len(prompt)} characters")
+    print(f"[DEBUG] Input tokens: {input_ids.shape[1]} tokens")
     stopping_criteria = None
 
     # Set seed for deterministic generation
@@ -381,6 +389,10 @@ def generate_response(message_text, image_input, max_output_tokens=4096, repetit
 
     # Generate response using deterministic greedy decoding
     # This eliminates randomness and ensures consistent responses
+    print(f"[DEBUG] About to generate with max_new_tokens: {max_output_tokens}")
+    print(f"[DEBUG] Model device: {our_chatbot.model.device}")
+    print(f"[DEBUG] Image tensor device: {image_tensor.device}")
+
     with torch.no_grad():
         outputs = our_chatbot.model.generate(
             inputs=input_ids,
@@ -392,6 +404,7 @@ def generate_response(message_text, image_input, max_output_tokens=4096, repetit
             pad_token_id=our_chatbot.tokenizer.eos_token_id,
             eos_token_id=our_chatbot.tokenizer.eos_token_id,
             length_penalty=1.0, # Don't penalize longer sequences
+            early_stopping=False, # Ensure no early stopping
         )
 
     # Decode response
@@ -399,6 +412,8 @@ def generate_response(message_text, image_input, max_output_tokens=4096, repetit
     print(f"[DEBUG] Outputs shape: {outputs.shape if hasattr(outputs, 'shape') else 'No shape attr'}")
     print(f"[DEBUG] Outputs length: {len(outputs) if hasattr(outputs, '__len__') else 'No length'}")
    print(f"[DEBUG] Input IDs shape: {input_ids.shape}")
+    print(f"[DEBUG] Generated tokens: {outputs.shape[1] - input_ids.shape[1] if hasattr(outputs, 'shape') else 'Unknown'}")
+    print(f"[DEBUG] Expected max tokens: {max_output_tokens}")
 
     if len(outputs) == 0:
         return {"error": "Model generated empty output"}
@@ -509,7 +524,7 @@ def initialize_model():
         self.model_base = None
         self.num_gpus = 1
         self.conv_mode = None
-        self.max_new_tokens = 1024
+        self.max_new_tokens = 4096
         self.num_frames = 16
         self.load_8bit = False
         self.load_4bit = False
@@ -589,6 +604,16 @@ def query(payload):
     repetition_penalty = float(payload.get("repetition_penalty", 1.0))
     conv_mode_override = payload.get("conv_mode", None)
 
+    # Debug: Log all generation parameters
+    print(f"[DEBUG] Generation parameters:")
+    print(f"[DEBUG] max_output_tokens: {max_output_tokens}")
+    print(f"[DEBUG] repetition_penalty: {repetition_penalty}")
+    print(f"[DEBUG] Original payload max_output_tokens: {payload.get('max_output_tokens')}")
+    print(f"[DEBUG] Original payload max_new_tokens: {payload.get('max_new_tokens')}")
+    print(f"[DEBUG] Original payload max_tokens: {payload.get('max_tokens')}")
+    print(f"[DEBUG] Full payload keys: {list(payload.keys())}")
+    print(f"[DEBUG] Payload values: {dict(payload)}")
+
     if not message_text or not message_text.strip():
         return {"error": "Missing prompt text. Use 'message', 'query', 'prompt', or 'istem' key"}
 
 
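A quick way to exercise the new debug logging and the max_new_tokens bump end to end is to call query() directly. The sketch below is a minimal, hypothetical smoke test: it assumes handler.py is importable and that query(payload) returns a dict, as the error paths above suggest. The payload keys come from this diff; the values are illustrative, and the image key is omitted because it is not shown here.

    # smoke_test.py - hypothetical check for the new [DEBUG] output (not part of this PR)
    from handler import query

    payload = {
        "message": "Describe the input in one paragraph.",  # 'query', 'prompt', or 'istem' also work
        "max_output_tokens": 4096,    # should be echoed by the new parameter logging
        "repetition_penalty": 1.0,
        "conv_mode": None,            # optional override; None keeps the default template
    }

    result = query(payload)
    print(result.get("error", result))

One note on the generate() change: in Hugging Face transformers, early_stopping only applies to beam search, so under the greedy decoding described in the comments it is a defensive no-op rather than a behavior change.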