Cook_with_a_LLM / scripts /diag_planner.py
Fred1e4's picture
Complete Cook App (#5)
75c5414
raw
history blame contribute delete
2.64 kB
"""Diagnose why the fine-tuned planner produces empty generations.

Usage:
    modal run scripts/diag_planner.py
"""
import modal
app = modal.App("cook-with-me-diag")

# Container image for the remote diagnostic run. The transformers pin targets
# the window that has BOTH CacheLayerMixin and is_torch_fx_available.
image = modal.Image.debian_slim(python_version="3.12").pip_install(
    "torch==2.4.0",
    "transformers>=4.54,<5.0",
    "huggingface_hub>=0.26,<1.0",
    "accelerate",
    "sentencepiece",
)

# Modal secret looked up by name; presumably carries the Hugging Face token
# needed to pull the checkpoints -- confirm in the Modal dashboard.
hf_secret = modal.Secret.from_name("huggingface-secret")

# Fine-tuned planner weights (under transformers 4.x); diag() loads the
# tokenizer from the base model rather than from this repo.
MODEL_ID = "eldinosaur/cook-with-me-planner-8b"
@app.function(image=image, gpu="L4", secrets=[hf_secret], timeout=900)
def diag():
    """Load the fine-tuned planner on a GPU and run one greedy chat generation.

    The tokenizer is loaded from the base MiniCPM checkpoint and the weights
    from ``MODEL_ID``. The function prints the installed transformers
    version, whether the model exposes ``generate`` and its class MRO. It
    then applies the chat template to a fixed recipe prompt and greedily
    generates up to 400 new tokens with the KV cache explicitly enabled.

    To explain *empty* generations, it reports more than the cleaned text:
    the number of new tokens, their raw ids, and the decode with special
    tokens kept. Together these separate "nothing generated", "only special
    tokens (e.g. an immediate EOS)" and real output.

    Exceptions raised while templating or generating are caught and reported
    with a full traceback. Errors while loading the tokenizer or model still
    propagate.
    """
    import traceback

    import torch
    import transformers

    # Print the version before importing the Auto* classes so it is visible
    # even when that import itself fails on an incompatible transformers.
    print("transformers version:", transformers.__version__)
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_tokenizer_id = "openbmb/MiniCPM4.1-8B"

    print("Loading tokenizer (from base) + model (from FT)...")
    tok = AutoTokenizer.from_pretrained(base_tokenizer_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="cuda"
    ).eval()
    print("has generate:", hasattr(model, "generate"))
    print("class mro:", [c.__name__ for c in type(model).__mro__])

    prompt = (
        "You are a chef. Given ingredients: tomato, onion, garlic, pasta, olive oil.\n"
        'Return ONLY JSON: {"options": [{"name": "...", "why": "..."}, ...]} with 3 dish ideas.'
    )
    messages = [{"role": "user", "content": prompt}]

    # Mirror the fixed planner.py path
    try:
        enc = tok.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_tensors="pt", return_dict=True,
        )
        input_ids = enc["input_ids"].to("cuda")
        input_len = input_ids.shape[1]
        gen_inputs = {"input_ids": input_ids}
        if enc.get("attention_mask") is not None:
            gen_inputs["attention_mask"] = enc["attention_mask"].to("cuda")
        print("input length:", input_len)

        with torch.no_grad():
            out = model.generate(
                **gen_inputs,
                max_new_tokens=400,
                do_sample=False,
                # Explicit so the "cache on" banner below is true regardless of
                # what the checkpoint's generation config defaults to.
                use_cache=True,
            )

        # Everything after the prompt is newly generated.
        new_tokens = out[0][input_len:]
        text = tok.decode(new_tokens, skip_special_tokens=True)
        raw_text = tok.decode(new_tokens, skip_special_tokens=False)
        gen_cfg = getattr(model, "generation_config", None)

        print("=== GENERATION OK (transformers 4.x, cache on) ===")
        print("new token count:", new_tokens.numel())
        print("new token ids (first 50):", new_tokens[:50].tolist())
        print(
            "eos ids (tokenizer / generation_config):",
            tok.eos_token_id,
            getattr(gen_cfg, "eos_token_id", None),
        )
        print("RAW (special tokens kept):", repr(raw_text[:1000]))
        print("OUTPUT:", repr(text[:1000]))
        if not text.strip():
            # The cleaned text hides special tokens; the ids above show whether
            # the model produced nothing, stopped on EOS, or emitted only specials.
            print("NOTE: decoded output is empty or whitespace-only; inspect the token ids above.")
    except Exception as e:
        print("=== GENERATION FAILED ===")
        print("Exception type:", type(e).__name__)
        print("Exception repr:", repr(e))
        traceback.print_exc()
@app.local_entrypoint()
def main() -> None:
    """Local entry point used by ``modal run``: invoke ``diag`` remotely, once."""
    diag.remote()