# symbol-fim-model / eval/eval_fim.py
# Uploaded by ethanker via huggingface_hub (commit 59f0d85, verified).
from itertools import islice
from typing import Dict, Optional

import torch
def eval_fim(model, dataloader, device: torch.device, max_batches: Optional[int] = None) -> Dict[str, float]:
    """Compute the mean per-batch ``lm_loss`` of *model* over *dataloader*.

    Each batch is a mapping that must contain ``"input_ids"`` and may contain
    ``"attention_mask"`` and ``"labels"``. Tensor values are moved to
    *device*; any other values are passed through unchanged. The model is
    called with those three keyword arguments, and its output is read with
    ``out.get("lm_loss")``. A batch whose output has no ``"lm_loss"`` (or has
    ``None``) is skipped and does not contribute to the average.

    The model is switched to eval mode and gradients are disabled for the
    pass. Afterwards the model's previous train/eval mode is restored, even if
    an exception is raised.

    Args:
        model: Module to evaluate. Its output must support ``.get``.
        dataloader: Iterable of batch mappings.
        device: Device that tensors in each batch are moved to.
        max_batches: If given, read at most this many batches from
            *dataloader*. Skipped batches count toward the limit, and ``0``
            evaluates nothing. ``None`` means no limit.

    Returns:
        ``{"loss": mean}``, where ``mean`` is the unweighted average of the
        per-batch losses, or ``nan`` when no batch produced a loss.

    Raises:
        ValueError: If *max_batches* is negative.
    """
    if max_batches is not None and max_batches < 0:
        raise ValueError(f"max_batches must be non-negative or None, got {max_batches}")

    # Remember the caller's mode so we don't force the model into train mode.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    batches = 0
    try:
        with torch.no_grad():
            # islice(it, None) is unbounded. With a limit, islice stops after
            # exactly max_batches items without fetching an extra batch.
            for batch in islice(dataloader, max_batches):
                batch = {
                    k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }
                out = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch.get("attention_mask"),
                    labels=batch.get("labels"),
                )
                loss = out.get("lm_loss")
                if loss is None:
                    continue
                total_loss += float(loss.item())
                batches += 1
    finally:
        model.train(was_training)
    if batches == 0:
        return {"loss": float("nan")}
    return {"loss": total_loss / batches}