|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Debug: Check what the fine-tuned model actually outputs""" |
|
|
|
|
|
from datasets import load_dataset |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
from peft import PeftModel |
|
|
import torch |
|
|
|
|
|
def main():
    """Load the fine-tuned EOT-detector adapter and print its raw generations.

    For the first few samples of the test set, formats the conversation in
    ChatML style, greedily generates a short continuation, and prints the
    output alongside the ground-truth label so the model's actual behavior
    can be inspected by eye.
    """
    # Single source of truth for the base checkpoint: the tokenizer MUST come
    # from the same checkpoint the adapter was trained on.
    base_model_id = "HuggingFaceTB/SmolLM2-135M"
    adapter_id = "Vurtnec/eot-detector-smollm2"
    num_samples = 6  # how many test samples to inspect

    print("Loading model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
    model = PeftModel.from_pretrained(base_model, adapter_id)
    model = model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    print("\nLoading test dataset...")
    dataset = load_dataset("Vurtnec/eot-detection-testset", split="train")

    print("\n" + "=" * 60)
    print("DEBUGGING MODEL OUTPUTS")
    print("=" * 60)

    for i, sample in enumerate(dataset):
        if i >= num_samples:
            break

        messages = sample["messages"]
        is_complete = sample["is_complete"]

        # ChatML-style formatting; presumably this mirrors the format used at
        # fine-tuning time — confirm against the training script.
        formatted = "".join(
            f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
            for msg in messages
        )

        inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=256)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                do_sample=False,  # greedy decoding for reproducible debugging
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens; keep special tokens visible
        # so markers such as <|im_end|> show up in the printed output.
        prompt_len = inputs["input_ids"].shape[1]
        generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)

        print(f"\n--- Sample {i+1} ---")
        print(f"Ground Truth: {'Complete' if is_complete else 'Incomplete'}")
        print(f"Last message: {messages[-1]['content'][:50]}...")
        print(f"Model output: '{generated}'")
        print(f"Contains 'eot': {'eot' in generated.lower()}")
        print(f"Contains 'continue': {'continue' in generated.lower()}")
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|