from typing import Dict, Optional

import torch


def eval_fim(model, dataloader, device: torch.device, max_batches: Optional[int] = None) -> Dict[str, float]:
    """Compute the mean ``lm_loss`` of *model* over *dataloader*.

    The model is put into eval mode for the duration of the loop and
    switched back to train mode afterwards — including when the forward
    pass raises (``try``/``finally``).

    Args:
        model: Called as ``model(input_ids=..., attention_mask=..., labels=...)``
            and expected to return a mapping; batches whose output has no
            ``"lm_loss"`` key are skipped. (Presumably "FIM" means
            fill-in-the-middle evaluation — TODO confirm with callers.)
        dataloader: Iterable of dict batches; every value is assumed to
            support ``.to(device)`` — verify against the collate function.
        device: Device the batch tensors are moved to.
        max_batches: If given, stop after this many dataloader batches,
            counting skipped (loss-less) batches too.

    Returns:
        ``{"loss": mean_loss}``, or ``{"loss": nan}`` when no batch
        produced a loss (e.g. empty dataloader).
    """
    model.eval()
    total_loss = 0.0
    batches = 0
    try:
        with torch.no_grad():
            for step, batch in enumerate(dataloader):
                batch = {k: v.to(device) for k, v in batch.items()}
                out = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch.get("attention_mask"),
                    labels=batch.get("labels"),
                )
                loss = out.get("lm_loss")
                if loss is not None:
                    total_loss += float(loss.item())
                    batches += 1
                # Bug fix: the original `continue` on a loss-less batch
                # bypassed this check, so the loop could overrun max_batches.
                # The limit counts raw dataloader steps, skipped or not.
                if max_batches is not None and step + 1 >= max_batches:
                    break
    finally:
        # Restore train mode even if a forward pass raised. Unconditional
        # model.train() matches the original behavior; note it will flip a
        # model that was already in eval mode into train mode.
        model.train()
    if batches == 0:
        # No usable batches: return NaN rather than dividing by zero.
        return {"loss": float("nan")}
    return {"loss": total_loss / batches}