Spaces:
Sleeping
Sleeping
Commit
·
b4fd5e9
1
Parent(s):
54880b1
Fix vLLM token parameter and improve streaming error handling
Browse files
- Remove 'token' parameter from vLLM LLM() call (uses HF_TOKEN env var)
- Add better error handling for generation thread
- Add debug logging for streamer token consumption
- Add timeout handling for generation thread
app.py
CHANGED
|
@@ -138,10 +138,10 @@ def load_vllm_model(model_name: str):
|
|
| 138 |
try:
|
| 139 |
# vLLM configuration optimized for ZeroGPU H200 slice
|
| 140 |
# vLLM natively supports AWQ via llm-compressor (replaces deprecated AutoAWQ)
|
|
|
|
| 141 |
llm_kwargs = {
|
| 142 |
"model": repo,
|
| 143 |
"trust_remote_code": True,
|
| 144 |
-
"token": HF_TOKEN,
|
| 145 |
"dtype": "bfloat16", # Prefer bf16 over int8 for speed
|
| 146 |
"gpu_memory_utilization": 0.90, # Leave headroom for KV cache
|
| 147 |
"max_model_len": 16384, # Adjust based on GPU memory
|
|
@@ -675,45 +675,79 @@ def _generate_router_plan_streaming_internal(
|
|
| 675 |
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 676 |
}
|
| 677 |
|
|
|
|
|
|
|
| 678 |
def _generate():
|
| 679 |
-
|
| 680 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
|
| 682 |
thread = Thread(target=_generate)
|
| 683 |
thread.start()
|
| 684 |
|
| 685 |
-
|
| 686 |
completion = ""
|
| 687 |
parsed_plan: Dict[str, Any] | None = None
|
| 688 |
validation_msg = "🔄 Generating..."
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
|
| 715 |
# Final processing after streaming completes
|
| 716 |
-
thread.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 717 |
|
| 718 |
completion = trim_at_stop_sequences(completion.strip())[0]
|
| 719 |
print(f"[DEBUG] Final completion length: {len(completion)}")
|
|
|
|
| 138 |
try:
|
| 139 |
# vLLM configuration optimized for ZeroGPU H200 slice
|
| 140 |
# vLLM natively supports AWQ via llm-compressor (replaces deprecated AutoAWQ)
|
| 141 |
+
# Note: HF_TOKEN is passed via environment variable, not as a parameter
|
| 142 |
llm_kwargs = {
|
| 143 |
"model": repo,
|
| 144 |
"trust_remote_code": True,
|
|
|
|
| 145 |
"dtype": "bfloat16", # Prefer bf16 over int8 for speed
|
| 146 |
"gpu_memory_utilization": 0.90, # Leave headroom for KV cache
|
| 147 |
"max_model_len": 16384, # Adjust based on GPU memory
|
|
|
|
| 675 |
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
| 676 |
}
|
| 677 |
|
| 678 |
+
generation_error = None
|
| 679 |
+
|
| 680 |
def _generate():
|
| 681 |
+
nonlocal generation_error
|
| 682 |
+
try:
|
| 683 |
+
with torch.inference_mode():
|
| 684 |
+
model.generate(**generation_kwargs)
|
| 685 |
+
except Exception as e:
|
| 686 |
+
generation_error = e
|
| 687 |
+
print(f"[DEBUG] Generation thread error: {e}")
|
| 688 |
+
import traceback
|
| 689 |
+
traceback.print_exc()
|
| 690 |
|
| 691 |
thread = Thread(target=_generate)
|
| 692 |
thread.start()
|
| 693 |
|
| 694 |
+
# Stream tokens
|
| 695 |
completion = ""
|
| 696 |
parsed_plan: Dict[str, Any] | None = None
|
| 697 |
validation_msg = "🔄 Generating..."
|
| 698 |
+
|
| 699 |
+
print(f"[DEBUG] Starting to consume streamer...")
|
| 700 |
+
token_count = 0
|
| 701 |
+
|
| 702 |
+
try:
|
| 703 |
+
for new_text in streamer:
|
| 704 |
+
if generation_error:
|
| 705 |
+
raise generation_error
|
| 706 |
+
|
| 707 |
+
if new_text:
|
| 708 |
+
token_count += 1
|
| 709 |
+
completion += new_text
|
| 710 |
+
chunk = completion
|
| 711 |
+
finished = False
|
| 712 |
+
display_plan = parsed_plan or {}
|
| 713 |
+
|
| 714 |
+
chunk, finished = trim_at_stop_sequences(chunk)
|
| 715 |
+
|
| 716 |
+
try:
|
| 717 |
+
json_block = extract_json_from_text(chunk)
|
| 718 |
+
candidate_plan = json.loads(json_block)
|
| 719 |
+
ok, issues = validate_router_plan(candidate_plan)
|
| 720 |
+
validation_msg = format_validation_message(ok, issues)
|
| 721 |
+
parsed_plan = candidate_plan if ok else parsed_plan
|
| 722 |
+
display_plan = candidate_plan
|
| 723 |
+
except Exception:
|
| 724 |
+
# Ignore until JSON is complete
|
| 725 |
+
pass
|
| 726 |
+
|
| 727 |
+
yield chunk, display_plan, validation_msg, prompt
|
| 728 |
+
|
| 729 |
+
if finished:
|
| 730 |
+
completion = chunk
|
| 731 |
+
break
|
| 732 |
+
|
| 733 |
+
print(f"[DEBUG] Streamer finished. Received {token_count} tokens.")
|
| 734 |
+
except Exception as stream_error:
|
| 735 |
+
print(f"[DEBUG] Streamer error: {stream_error}")
|
| 736 |
+
import traceback
|
| 737 |
+
traceback.print_exc()
|
| 738 |
+
# Wait for thread to finish
|
| 739 |
+
thread.join(timeout=5.0)
|
| 740 |
+
if generation_error:
|
| 741 |
+
raise generation_error
|
| 742 |
+
raise stream_error
|
| 743 |
|
| 744 |
# Final processing after streaming completes
|
| 745 |
+
thread.join(timeout=30.0)
|
| 746 |
+
if thread.is_alive():
|
| 747 |
+
print("[DEBUG] WARNING: Generation thread still running after timeout")
|
| 748 |
+
|
| 749 |
+
if generation_error:
|
| 750 |
+
raise generation_error
|
| 751 |
|
| 752 |
completion = trim_at_stop_sequences(completion.strip())[0]
|
| 753 |
print(f"[DEBUG] Final completion length: {len(completion)}")
|