Spaces:
Sleeping
Sleeping
Varun Wadhwa
committed on
Logs
Browse files
app.py
CHANGED
|
@@ -78,31 +78,32 @@ print(raw_dataset.column_names)
|
|
| 78 |
# function to align labels with tokens
|
| 79 |
# --> special tokens: -100 label id (ignored by cross entropy),
|
| 80 |
# --> if tokens are inside a word, replace 'B-' with 'I-'
|
| 81 |
-
def align_labels_with_tokens(labels):
|
| 82 |
aligned_label_ids = []
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
return aligned_label_ids
|
| 90 |
|
| 91 |
# create tokenize function
|
| 92 |
def tokenize_function(examples):
|
| 93 |
-
# tokenize and truncate text. The examples argument would have already stripped
|
| 94 |
-
# the train or test label.
|
| 95 |
-
new_labels = []
|
| 96 |
inputs = tokenizer(
|
| 97 |
examples['mbert_tokens'],
|
| 98 |
is_split_into_words=True,
|
| 99 |
-
padding=True,
|
| 100 |
truncation=True,
|
| 101 |
-
max_length=512
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
inputs["labels"] =
|
|
|
|
|
|
|
|
|
|
| 106 |
return inputs
|
| 107 |
|
| 108 |
# tokenize training and validation datasets
|
|
@@ -111,54 +112,43 @@ tokenized_data = raw_dataset.map(
|
|
| 111 |
batched=True)
|
| 112 |
tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
|
| 113 |
# data collator
|
| 114 |
-
data_collator = DataCollatorForTokenClassification(
|
|
|
|
|
|
|
| 115 |
|
| 116 |
st.write(tokenized_data["train"][:2]["labels"])
|
| 117 |
|
| 118 |
# Function to evaluate model performance
|
| 119 |
def evaluate_model(model, dataloader, device):
|
| 120 |
-
model.eval()
|
| 121 |
-
all_preds = []
|
| 122 |
-
all_labels = []
|
| 123 |
|
| 124 |
-
# Disable gradient calculations
|
| 125 |
with torch.no_grad():
|
| 126 |
for batch in dataloader:
|
| 127 |
input_ids = batch['input_ids'].to(device)
|
| 128 |
attention_mask = batch['attention_mask'].to(device)
|
| 129 |
-
labels = batch['labels'].to(device)
|
| 130 |
-
|
| 131 |
-
print(x)
|
| 132 |
-
print("OTHERS:")
|
| 133 |
-
for l in labels:
|
| 134 |
-
if len(l) != x:
|
| 135 |
-
print(len(l))
|
| 136 |
-
break
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# Forward pass to get logits
|
| 140 |
outputs = model(input_ids, attention_mask=attention_mask)
|
| 141 |
logits = outputs.logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
-
# Get predictions
|
| 144 |
-
preds = torch.argmax(logits, dim=-1).cpu().numpy()
|
| 145 |
-
|
| 146 |
-
all_preds.extend(preds)
|
| 147 |
-
all_labels.extend(labels)
|
| 148 |
-
|
| 149 |
-
# Calculate evaluation metrics
|
| 150 |
-
print("evaluate_model sizes")
|
| 151 |
-
print(len(all_preds[0]))
|
| 152 |
-
print(len(all_labels[0]))
|
| 153 |
-
all_preds = np.asarray(all_preds, dtype=np.float32)
|
| 154 |
-
all_labels = np.asarray(all_labels, dtype=np.float32)
|
| 155 |
-
print("Flattened sizes")
|
| 156 |
-
print(all_preds.size)
|
| 157 |
-
print(all_labels.size)
|
| 158 |
-
all_preds = all_preds.flatten()
|
| 159 |
-
all_labels = all_labels.flatten()
|
| 160 |
accuracy = accuracy_score(all_labels, all_preds)
|
| 161 |
-
precision, recall, f1, _ = precision_recall_fscore_support(
|
|
|
|
|
|
|
| 162 |
|
| 163 |
return accuracy, precision, recall, f1
|
| 164 |
|
|
|
|
| 78 |
# function to align labels with tokens
|
| 79 |
# --> special tokens: -100 label id (ignored by cross entropy),
|
| 80 |
# --> if tokens are inside a word, replace 'B-' with 'I-'
|
| 81 |
+
def align_labels_with_tokens(labels, word_ids, max_length, label_map=None):
    """Expand word-level labels into token-level label ids.

    Args:
        labels: Sequence of label names, indexed by word position.
        word_ids: For each token, the index of the word it came from, or
            ``None`` for tokens that belong to no word (special/padding).
        max_length: The result is right-padded with -100 up to this length.
            Sequences already at least this long are returned unpadded.
        label_map: Optional mapping from label name to integer id. Defaults
            to the module-level ``label2id``.

    Returns:
        A list of integer label ids, one per entry of ``word_ids`` plus any
        padding. Tokens without a word get -100 so the loss ignores them.
    """
    if label_map is None:
        label_map = label2id

    aligned_label_ids = []
    previous_word_id = None
    for word_id in word_ids:
        if word_id is None:
            aligned_label_ids.append(-100)
        else:
            label = labels[word_id]
            # Only tokens *inside* a word (same word id as the previous token)
            # switch from 'B-' to 'I-'. The rewrite happens on the label name,
            # before the id lookup, because ids have no '.replace'.
            is_continuation = word_id == previous_word_id
            if is_continuation and isinstance(label, str) and label.startswith("B-"):
                inside_label = "I-" + label[2:]
                # Keep the original label when the tag set has no I- variant.
                if inside_label in label_map:
                    label = inside_label
            aligned_label_ids.append(label_map[label])
        previous_word_id = word_id

    # Pad to max length (a non-positive difference adds nothing).
    aligned_label_ids += [-100] * (max_length - len(aligned_label_ids))
    return aligned_label_ids
|
| 92 |
|
| 93 |
# create tokenize function
|
| 94 |
def tokenize_function(examples):
    """Tokenize a batch of pre-split examples and attach token-level labels.

    Args:
        examples: Batch mapping that provides 'mbert_tokens' (one word list
            per example) and 'mbert_token_classes' (one label list per
            example).

    Returns:
        The tokenizer output with an added "labels" entry holding one
        aligned label-id list per example.
    """
    # Single source of truth so inputs and labels are padded to the same length.
    max_length = 512
    inputs = tokenizer(
        examples['mbert_tokens'],
        is_split_into_words=True,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # word_ids() without an argument describes only the first example of the
    # batch, so each example must ask for its own batch index.
    inputs["labels"] = [
        align_labels_with_tokens(
            labels, inputs.word_ids(batch_index=index), max_length
        )
        for index, labels in enumerate(examples['mbert_token_classes'])
    ]
    return inputs
|
| 108 |
|
| 109 |
# tokenize training and validation datasets
|
|
|
|
| 112 |
batched=True)
|
| 113 |
# Expose only the model inputs and the labels, as torch tensors.
tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# data collator
# DataCollatorForTokenClassification takes no 'truncation' argument (passing
# one raises TypeError); truncation is already done by the tokenizer call.
data_collator = DataCollatorForTokenClassification(
    tokenizer, padding=True, max_length=512
)

# Show the labels of the first two training examples in the Streamlit app.
st.write(tokenized_data["train"][:2]["labels"])
|
| 120 |
|
| 121 |
# Function to evaluate model performance
|
| 122 |
def evaluate_model(model, dataloader, device):
    """Score a token-classification model over every batch of a dataloader.

    Args:
        model: Model called as ``model(input_ids, attention_mask=...)`` whose
            output exposes ``.logits``.
        dataloader: Iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; precision, recall and f1
        are micro-averaged. Positions labelled -100 are excluded.
    """
    model.eval()
    all_preds = []
    all_labels = []

    # Inference only: no gradients are needed.
    with torch.no_grad():
        for batch in dataloader:
            labels = batch['labels'].to(device)
            outputs = model(
                batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
            preds = torch.argmax(outputs.logits, dim=-1)

            # Keep only positions with a real label (-100 marks ignored ones).
            mask = labels != -100
            all_preds.extend(preds[mask].cpu().numpy())
            all_labels.extend(labels[mask].cpu().numpy())

    # Convert to numpy arrays for metrics calculation
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='micro'
    )

    return accuracy, precision, recall, f1
|
| 154 |
|