"""Diagnose why the fine-tuned planner produces empty generations. modal run scripts/diag_planner.py """ import modal app = modal.App("cook-with-me-diag") image = ( modal.Image.debian_slim(python_version="3.12") .pip_install( "torch==2.4.0", "transformers>=4.54,<5.0", # window with BOTH CacheLayerMixin and is_torch_fx_available "huggingface_hub>=0.26,<1.0", "accelerate", "sentencepiece", ) ) hf_secret = modal.Secret.from_name("huggingface-secret") MODEL_ID = "eldinosaur/cook-with-me-planner-8b" # fine-tuned model under transformers 4.x @app.function(image=image, gpu="L4", secrets=[hf_secret], timeout=900) def diag(): import torch import transformers print("transformers version:", transformers.__version__) from transformers import AutoModelForCausalLM, AutoTokenizer print("Loading tokenizer (from base) + model (from FT)...") tok = AutoTokenizer.from_pretrained("openbmb/MiniCPM4.1-8B", trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cuda" ).eval() print("has generate:", hasattr(model, "generate")) print("class mro:", [c.__name__ for c in type(model).__mro__]) prompt = ( "You are a chef. Given ingredients: tomato, onion, garlic, pasta, olive oil.\n" 'Return ONLY JSON: {"options": [{"name": "...", "why": "..."}, ...]} with 3 dish ideas.' ) messages = [{"role": "user", "content": prompt}] # Mirror the fixed planner.py path try: enc = tok.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True, ) input_ids = enc["input_ids"].to("cuda") input_len = input_ids.shape[1] gen_inputs = {"input_ids": input_ids} if enc.get("attention_mask") is not None: gen_inputs["attention_mask"] = enc["attention_mask"].to("cuda") print("input length:", input_len) with torch.no_grad(): out = model.generate(**gen_inputs, max_new_tokens=400, do_sample=False) text = tok.decode(out[0][input_len:], skip_special_tokens=True) print("=== GENERATION OK (transformers 4.x, cache on) ===") print("OUTPUT:", repr(text[:1000])) except Exception as e: import traceback print("=== GENERATION FAILED ===") print("Exception type:", type(e).__name__) print("Exception repr:", repr(e)) traceback.print_exc() @app.local_entrypoint() def main(): diag.remote()