ethanker commited on
Commit
59f0d85
·
verified ·
1 Parent(s): da64215

Upload eval/eval_fim.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval/eval_fim.py +33 -0
eval/eval_fim.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+
3
+ import torch
4
+
5
+
6
+ def eval_fim(model, dataloader, device: torch.device, max_batches: Optional[int] = None) -> Dict[str, float]:
7
+ model.eval()
8
+ total_loss = 0.0
9
+ batches = 0
10
+
11
+ with torch.no_grad():
12
+ for step, batch in enumerate(dataloader):
13
+ batch = {k: v.to(device) for k, v in batch.items()}
14
+ out = model(
15
+ input_ids=batch["input_ids"],
16
+ attention_mask=batch.get("attention_mask"),
17
+ labels=batch.get("labels"),
18
+ )
19
+ loss = out.get("lm_loss")
20
+ if loss is None:
21
+ continue
22
+ total_loss += float(loss.item())
23
+ batches += 1
24
+ if max_batches is not None and (step + 1) >= max_batches:
25
+ break
26
+
27
+ model.train()
28
+
29
+ if batches == 0:
30
+ return {"loss": float("nan")}
31
+
32
+ return {"loss": total_loss / batches}
33
+