Alikestocode committed on
Commit
03689e3
·
1 Parent(s): 06aef1b

Fix Gradio UI structure and add comprehensive fallback logging

Browse files

- Fix UI structure: move buttons to separate Row, fix Column nesting
- Add detailed fallback chain logging with emoji indicators
- Show clear progression: vLLM → Transformers AWQ → BitsAndBytes → FP16/FP32
- Improve error messages to show which fallback path is being used
- All fallback paths now properly logged for debugging

Files changed (1) hide show
  1. app.py +81 -59
app.py CHANGED
@@ -212,41 +212,54 @@ def load_awq_pipeline(repo: str, tokenizer):
212
 
213
 
214
  def load_pipeline(model_name: str):
215
- """Load model with vLLM (preferred) or Transformers (fallback)."""
 
 
 
 
 
 
 
 
216
  # Try vLLM first (best performance with native AWQ support via llm-compressor)
217
  # vLLM handles AWQ natively, so AutoAWQ deprecation doesn't affect us
218
  if VLLM_AVAILABLE:
219
  try:
220
- print(f"Attempting to load {model_name} with vLLM (native AWQ support)...")
221
  return load_vllm_model(model_name)
222
  except Exception as exc:
223
- print(f"⚠️ vLLM load failed, falling back to Transformers: {exc}")
 
224
  import traceback
225
  traceback.print_exc()
226
 
227
  # Fallback to Transformers pipeline
228
  if model_name in PIPELINES:
 
229
  return PIPELINES[model_name]
230
 
231
  repo = MODELS[model_name]["repo_id"]
232
  tokenizer = get_tokenizer(repo)
233
 
234
- # Try AWQ first if available
235
  if AWQ_AVAILABLE:
236
  try:
237
- print(f"Loading {repo} with AWQ quantization...")
238
  pipe = load_awq_pipeline(repo, tokenizer)
239
  PIPELINES[model_name] = pipe
240
  _schedule_background_warm(model_name)
241
  # Warm kernels immediately after loading
242
  Thread(target=lambda: _warm_kernels(model_name), daemon=True).start()
 
243
  return pipe
244
  except Exception as exc:
245
- print(f"AWQ load failed for {repo}: {exc}. Falling back to BitsAndBytes.")
 
246
 
247
  # Fallback to BitsAndBytes 8-bit
248
  if BITSANDBYTES_AVAILABLE:
249
  try:
 
250
  quant_config = BitsAndBytesConfig(load_in_8bit=True)
251
  model_kwargs = {"quantization_config": quant_config}
252
  if FLASH_ATTN_AVAILABLE:
@@ -275,13 +288,17 @@ def load_pipeline(model_name: str):
275
 
276
  PIPELINES[model_name] = pipe
277
  _schedule_background_warm(model_name)
 
278
  return pipe
279
  except Exception as exc:
280
- print(f"8-bit load failed for {repo}: {exc}. Falling back to higher precision.")
 
281
 
282
- # Fallback to bfloat16/fp16/fp32
283
  for dtype in (torch.bfloat16, torch.float16, torch.float32):
 
284
  try:
 
285
  model_kwargs = {}
286
  if FLASH_ATTN_AVAILABLE:
287
  model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -308,11 +325,14 @@ def load_pipeline(model_name: str):
308
 
309
  PIPELINES[model_name] = pipe
310
  _schedule_background_warm(model_name)
 
311
  return pipe
312
- except Exception:
 
313
  continue
314
 
315
- # Final fallback
 
316
  model_kwargs = {}
317
  if FLASH_ATTN_AVAILABLE:
318
  model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -338,6 +358,7 @@ def load_pipeline(model_name: str):
338
 
339
  PIPELINES[model_name] = pipe
340
  _schedule_background_warm(model_name)
 
341
  return pipe
342
 
343
 
@@ -788,56 +809,57 @@ def build_ui():
788
  """) as demo:
789
  gr.Markdown("# πŸ›°οΈ Router Control Room β€” ZeroGPU" )
790
  gr.Markdown(description)
791
-
792
- with gr.Row():
793
- with gr.Column(scale=3):
794
- user_task = gr.Textbox(
795
- label="User Task / Problem Statement",
796
- placeholder="Describe the homework-style query that needs routing...",
797
- lines=8,
798
- value="Explain how to solve a constrained optimization homework problem that mixes calculus and coding steps.",
799
- )
800
- context = gr.Textbox(
801
- label="Supporting Context (optional)",
802
- placeholder="Paste any retrieved evidence, PDFs, or rubric notes.",
803
- lines=4,
804
- )
805
- acceptance = gr.Textbox(
806
- label="Acceptance Criteria",
807
- placeholder="Bullet list of 'definition of done' checks.",
808
- lines=3,
809
- value="- Provide citations for every claim.\n- Ensure /math verifies /code output.",
810
- )
811
- extra_guidance = gr.Textbox(
812
- label="Additional Guidance",
813
- placeholder="Special constraints, tools to avoid, etc.",
814
- lines=3,
815
- )
816
- with gr.Column(scale=2):
817
- model_choice = gr.Dropdown(
818
- label="Router Checkpoint",
819
- choices=list(MODELS.keys()),
820
- value=list(MODELS.keys())[0] if MODELS else None,
821
- allow_custom_value=False,
822
- )
823
- difficulty = gr.Radio(
824
- label="Difficulty Tier",
825
- choices=["introductory", "intermediate", "advanced"],
826
- value="advanced",
827
- interactive=True,
828
- )
829
- tags = gr.Textbox(
830
- label="Tags",
831
- placeholder="Comma-separated e.g. calculus, optimization, python",
832
- value="calculus, optimization, python",
833
- )
834
- max_new_tokens = gr.Slider(256, 20000, value=16000, step=32, label="Max New Tokens")
835
- temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
836
- top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
837
- gpu_duration = gr.Slider(60, 1800, value=600, step=60, label="GPU Duration (seconds)", info="Maximum GPU time allocation for this request")
838
 
839
- generate_btn = gr.Button("Generate Router Plan", variant="primary")
840
- clear_btn = gr.Button("Clear", variant="secondary")
 
841
 
842
  with gr.Row():
843
  raw_output = gr.Textbox(label="Raw Model Output", lines=12)
 
212
 
213
 
214
  def load_pipeline(model_name: str):
215
+ """Load model with vLLM (preferred) or Transformers (fallback).
216
+
217
+ Fallback chain:
218
+ 1. vLLM with AWQ (best performance, continuous batching)
219
+ 2. vLLM with FP16 (if AWQ not available)
220
+ 3. Transformers with AWQ (via AutoAWQ - deprecated but functional)
221
+ 4. Transformers with BitsAndBytes 8-bit
222
+ 5. Transformers with FP16/FP32
223
+ """
224
  # Try vLLM first (best performance with native AWQ support via llm-compressor)
225
  # vLLM handles AWQ natively, so AutoAWQ deprecation doesn't affect us
226
  if VLLM_AVAILABLE:
227
  try:
228
+ print(f"πŸ”„ Attempting to load {model_name} with vLLM (native AWQ support)...")
229
  return load_vllm_model(model_name)
230
  except Exception as exc:
231
+ print(f"⚠️ vLLM load failed: {exc}")
232
+ print(f" β†’ Falling back to Transformers pipeline...")
233
  import traceback
234
  traceback.print_exc()
235
 
236
  # Fallback to Transformers pipeline
237
  if model_name in PIPELINES:
238
+ print(f"βœ… Using cached Transformers pipeline for {model_name}")
239
  return PIPELINES[model_name]
240
 
241
  repo = MODELS[model_name]["repo_id"]
242
  tokenizer = get_tokenizer(repo)
243
 
244
+ # Try AWQ first if available (Transformers fallback path)
245
  if AWQ_AVAILABLE:
246
  try:
247
+ print(f"πŸ”„ Loading {repo} with Transformers + AutoAWQ (fallback path)...")
248
  pipe = load_awq_pipeline(repo, tokenizer)
249
  PIPELINES[model_name] = pipe
250
  _schedule_background_warm(model_name)
251
  # Warm kernels immediately after loading
252
  Thread(target=lambda: _warm_kernels(model_name), daemon=True).start()
253
+ print(f"βœ… Transformers + AutoAWQ pipeline loaded: {model_name}")
254
  return pipe
255
  except Exception as exc:
256
+ print(f"⚠️ AutoAWQ load failed for {repo}: {exc}")
257
+ print(f" β†’ Falling back to BitsAndBytes 8-bit...")
258
 
259
  # Fallback to BitsAndBytes 8-bit
260
  if BITSANDBYTES_AVAILABLE:
261
  try:
262
+ print(f"πŸ”„ Loading {repo} with BitsAndBytes 8-bit quantization...")
263
  quant_config = BitsAndBytesConfig(load_in_8bit=True)
264
  model_kwargs = {"quantization_config": quant_config}
265
  if FLASH_ATTN_AVAILABLE:
 
288
 
289
  PIPELINES[model_name] = pipe
290
  _schedule_background_warm(model_name)
291
+ print(f"βœ… BitsAndBytes 8-bit pipeline loaded: {model_name}")
292
  return pipe
293
  except Exception as exc:
294
+ print(f"⚠️ BitsAndBytes 8-bit load failed for {repo}: {exc}")
295
+ print(f" β†’ Falling back to FP16/FP32...")
296
 
297
+ # Fallback to bfloat16/fp16/fp32 (unquantized)
298
  for dtype in (torch.bfloat16, torch.float16, torch.float32):
299
+ dtype_name = {torch.bfloat16: "bfloat16", torch.float16: "float16", torch.float32: "float32"}[dtype]
300
  try:
301
+ print(f"πŸ”„ Loading {repo} with {dtype_name} precision...")
302
  model_kwargs = {}
303
  if FLASH_ATTN_AVAILABLE:
304
  model_kwargs["attn_implementation"] = "flash_attention_2"
 
325
 
326
  PIPELINES[model_name] = pipe
327
  _schedule_background_warm(model_name)
328
+ print(f"βœ… {dtype_name} pipeline loaded: {model_name}")
329
  return pipe
330
+ except Exception as exc:
331
+ print(f"⚠️ {dtype_name} load failed: {exc}")
332
  continue
333
 
334
+ # Final fallback (no quantization, no FlashAttention)
335
+ print(f"⚠️ All quantization methods failed, using basic pipeline...")
336
  model_kwargs = {}
337
  if FLASH_ATTN_AVAILABLE:
338
  model_kwargs["attn_implementation"] = "flash_attention_2"
 
358
 
359
  PIPELINES[model_name] = pipe
360
  _schedule_background_warm(model_name)
361
+ print(f"βœ… Basic pipeline loaded: {model_name}")
362
  return pipe
363
 
364
 
 
809
  """) as demo:
810
  gr.Markdown("# πŸ›°οΈ Router Control Room β€” ZeroGPU" )
811
  gr.Markdown(description)
812
+
813
+ with gr.Row():
814
+ with gr.Column(scale=3):
815
+ user_task = gr.Textbox(
816
+ label="User Task / Problem Statement",
817
+ placeholder="Describe the homework-style query that needs routing...",
818
+ lines=8,
819
+ value="Explain how to solve a constrained optimization homework problem that mixes calculus and coding steps.",
820
+ )
821
+ context = gr.Textbox(
822
+ label="Supporting Context (optional)",
823
+ placeholder="Paste any retrieved evidence, PDFs, or rubric notes.",
824
+ lines=4,
825
+ )
826
+ acceptance = gr.Textbox(
827
+ label="Acceptance Criteria",
828
+ placeholder="Bullet list of 'definition of done' checks.",
829
+ lines=3,
830
+ value="- Provide citations for every claim.\n- Ensure /math verifies /code output.",
831
+ )
832
+ extra_guidance = gr.Textbox(
833
+ label="Additional Guidance",
834
+ placeholder="Special constraints, tools to avoid, etc.",
835
+ lines=3,
836
+ )
837
+ with gr.Column(scale=2):
838
+ model_choice = gr.Dropdown(
839
+ label="Router Checkpoint",
840
+ choices=list(MODELS.keys()),
841
+ value=list(MODELS.keys())[0] if MODELS else None,
842
+ allow_custom_value=False,
843
+ )
844
+ difficulty = gr.Radio(
845
+ label="Difficulty Tier",
846
+ choices=["introductory", "intermediate", "advanced"],
847
+ value="advanced",
848
+ interactive=True,
849
+ )
850
+ tags = gr.Textbox(
851
+ label="Tags",
852
+ placeholder="Comma-separated e.g. calculus, optimization, python",
853
+ value="calculus, optimization, python",
854
+ )
855
+ max_new_tokens = gr.Slider(256, 20000, value=16000, step=32, label="Max New Tokens")
856
+ temperature = gr.Slider(0.0, 1.5, value=0.2, step=0.05, label="Temperature")
857
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
858
+ gpu_duration = gr.Slider(60, 1800, value=600, step=60, label="GPU Duration (seconds)", info="Maximum GPU time allocation for this request")
859
 
860
+ with gr.Row():
861
+ generate_btn = gr.Button("Generate Router Plan", variant="primary", scale=1)
862
+ clear_btn = gr.Button("Clear", variant="secondary", scale=1)
863
 
864
  with gr.Row():
865
  raw_output = gr.Textbox(label="Raw Model Output", lines=12)