Spaces:
Sleeping
Sleeping
Varun Wadhwa
commited on
Logs
Browse files
app.py
CHANGED
|
@@ -123,13 +123,9 @@ def evaluate_model(model, dataloader, device):
|
|
| 123 |
|
| 124 |
# Disable gradient calculations
|
| 125 |
with torch.no_grad():
|
| 126 |
-
for batch in dataloader:
|
| 127 |
-
print("Sample sequence labels:", batch['labels'][0].tolist()[:20])
|
| 128 |
-
print("Corresponding predictions:", torch.argmax(model(batch['input_ids'].to(device),
|
| 129 |
-
attention_mask=batch['attention_mask'].to(device)).logits, dim=-1)[0].tolist()[:20])
|
| 130 |
-
break
|
| 131 |
for batch in dataloader:
|
| 132 |
input_ids = batch['input_ids'].to(device)
|
|
|
|
| 133 |
attention_mask = batch['attention_mask'].to(device)
|
| 134 |
labels = batch['labels'].to(device).cpu().numpy()
|
| 135 |
|
|
@@ -140,11 +136,15 @@ def evaluate_model(model, dataloader, device):
|
|
| 140 |
# Get predictions
|
| 141 |
preds = torch.argmax(logits, dim=-1).cpu().numpy()
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
# Calculate evaluation metrics
|
| 150 |
print("evaluate_model sizes")
|
|
|
|
| 123 |
|
| 124 |
# Disable gradient calculations
|
| 125 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
for batch in dataloader:
|
| 127 |
input_ids = batch['input_ids'].to(device)
|
| 128 |
+
current_batch_size = input_ids.size(0)
|
| 129 |
attention_mask = batch['attention_mask'].to(device)
|
| 130 |
labels = batch['labels'].to(device).cpu().numpy()
|
| 131 |
|
|
|
|
| 136 |
# Get predictions
|
| 137 |
preds = torch.argmax(logits, dim=-1).cpu().numpy()
|
| 138 |
|
| 139 |
+
# Use attention mask to get valid tokens
|
| 140 |
+
mask = batch['attention_mask'].cpu().numpy().astype(bool)
|
| 141 |
|
| 142 |
+
# Process each sequence in the batch
|
| 143 |
+
for i in range(current_batch_size):
|
| 144 |
+
valid_preds = preds[i][mask[i]].flatten()
|
| 145 |
+
valid_labels = labels[i][mask[i]].flatten()
|
| 146 |
+
all_preds.extend(valid_preds.tolist())
|
| 147 |
+
all_labels.extend(valid_labels.tolist())
|
| 148 |
|
| 149 |
# Calculate evaluation metrics
|
| 150 |
print("evaluate_model sizes")
|