eot-detector-smollm2 / debug_model_output.py
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "datasets>=3.1",
# "transformers>=4.46",
# "torch>=2.5",
# "peft>=0.13",
# ]
# ///
"""Debug: Check what the fine-tuned model actually outputs"""
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch


def main():
    print("Loading model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    # Load the base model, then attach the fine-tuned PEFT adapter on top of it
    base_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    model = PeftModel.from_pretrained(base_model, "Vurtnec/eot-detector-smollm2")
    model = model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
print("\nLoading test dataset...")
dataset = load_dataset("Vurtnec/eot-detection-testset", split="train")
print("\n" + "=" * 60)
print("DEBUGGING MODEL OUTPUTS")
print("=" * 60)
    for i, sample in enumerate(dataset):
        if i >= 6:  # Only check the first 6 samples
            break

        messages = sample["messages"]
        is_complete = sample["is_complete"]

        # Format the conversation in ChatML style
        formatted = ""
        for msg in messages:
            formatted += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
        # Run inference (greedy decoding, no sampling)
        inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=256)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens; keep special tokens visible
        generated = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False
        )
print(f"\n--- Sample {i+1} ---")
print(f"Ground Truth: {'Complete' if is_complete else 'Incomplete'}")
print(f"Last message: {messages[-1]['content'][:50]}...")
print(f"Model output: '{generated}'")
print(f"Contains 'eot': {'eot' in generated.lower()}")
print(f"Contains 'continue': {'continue' in generated.lower()}")
if __name__ == "__main__":
main()