| from typing import Dict, Optional | |
| import torch | |
def eval_fim(model, dataloader, device: torch.device, max_batches: Optional[int] = None) -> Dict[str, float]:
    """Return the mean ``lm_loss`` of ``model`` over ``dataloader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Afterwards it is put back into training mode only if it was in training
    mode on entry (also when the evaluation raises).

    Args:
        model: Callable invoked with ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments. Its return value must support
            ``.get("lm_loss")``. Must provide ``eval()`` and ``train()``.
        dataloader: Iterable of dict batches. ``"input_ids"`` is required;
            ``"attention_mask"`` and ``"labels"`` are passed as ``None`` when
            absent.
        device: Device that batch values exposing ``.to`` are moved to.
        max_batches: Upper bound on the number of batches drawn from
            ``dataloader``; ``None`` means no limit. Batches without a loss
            count towards this limit. Values below 1 still evaluate a single
            batch.

    Returns:
        ``{"loss": mean}`` averaged over the batches that reported a loss, or
        ``{"loss": nan}`` when no batch reported one.
    """
    # Objects without a ``training`` flag are treated as "was training" so
    # that ``train()`` is still called on exit for them.
    was_training = getattr(model, "training", True)
    model.eval()
    total_loss = 0.0
    batches = 0
    try:
        with torch.no_grad():
            for step, batch in enumerate(dataloader):
                # Only move values that can be moved; metadata such as
                # strings or lists is passed through unchanged.
                batch = {
                    k: (v.to(device) if hasattr(v, "to") else v)
                    for k, v in batch.items()
                }
                out = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch.get("attention_mask"),
                    labels=batch.get("labels"),
                )
                loss = out.get("lm_loss")
                if loss is not None:
                    total_loss += float(loss.item())
                    batches += 1
                # Checked for every batch, including those without a loss,
                # so the limit cannot be skipped over.
                if max_batches is not None and (step + 1) >= max_batches:
                    break
    finally:
        if was_training:
            model.train()
    if batches == 0:
        return {"loss": float("nan")}
    return {"loss": total_loss / batches}