# Scalable_monarch_adapter / src /_eval_drop.py
# nvan13's picture
# Upload folder using huggingface_hub
# ecadbd9 verified
## FB 124M
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
import re
import string
import collections
import numpy as np
import json
from .config import MainConfig, convert_to_trainer_args
from smpeft.sama import SamaConfig #RotationTuner
from smpeft import get_peft_model, PeftModel
import draccus
import random
import transformers
# Module-level default batch size. main() takes its batch size from
# mainCfg.trainer_args.per_device_eval_batch_size, so this value is not read there.
BATCH_SIZE = 32
# NOTE(review): not referenced in the visible part of this file; presumably the
# label value ignored by the training loss -- confirm against the training code.
IGNORE_INDEX=-100
# NOTE(review): generate_batch() uses 20 new tokens by default rather than this
# constant -- confirm which generation length is intended.
MAX_NEW_TOKENS = 50
# Prompt template with an {instruction} placeholder.
# NOTE(review): generate_batch() formats prompts with its own locally defined
# template (different wording), so this module-level one is not used there.
PROMPT_TEMPLATE = (
"Below is an passage followed by a coresponding question that describes a task "
"Write a response that appropriately completes the request with your answer.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
)
def normalize_answer(s):
    """Return *s* lower-cased, without punctuation, articles or extra whitespace.

    The steps run in a fixed order: lowercase, drop every character found in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse all whitespace runs into single spaces.
    """
    punctuation = set(string.punctuation)
    lowered = s.lower()
    without_punctuation = ''.join(
        char for char in lowered if char not in punctuation
    )
    # Articles are removed after punctuation so that e.g. "the," is caught too.
    without_articles = re.sub(
        r'\b(a|an|the)\b', ' ', without_punctuation, flags=re.UNICODE
    )
    return ' '.join(without_articles.split())
def f1_score(prediction, ground_truth):
    """Return the token-level F1 between *prediction* and *ground_truth*.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Overlap is counted as a multiset intersection, so repeated
    tokens only match as many times as they occur on both sides. Returns 0
    when no token is shared.
    """
    predicted = normalize_answer(prediction).split()
    expected = normalize_answer(ground_truth).split()

    overlap = collections.Counter(predicted) & collections.Counter(expected)
    num_same = sum(overlap.values())
    if not num_same:
        return 0

    precision = 1.0 * num_same / len(predicted)
    recall = 1.0 * num_same / len(expected)
    return (2 * precision * recall) / (precision + recall)
def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Return the best score of *prediction* against any reference answer.

    DROP often has multiple valid answer spans, so the prediction is scored
    against each one and the maximum is kept.

    Args:
        metric_fn: Callable ``(prediction, ground_truth) -> score``.
        prediction: The predicted answer.
        ground_truths: Iterable of reference answers.

    Returns:
        The highest score produced by ``metric_fn``, or 0 when
        ``ground_truths`` is empty (nothing could be matched), instead of
        letting ``max()`` fail on an empty sequence.
    """
    return max(
        (metric_fn(prediction, ground_truth) for ground_truth in ground_truths),
        default=0,
    )
def set_seed(seed: int):
    """Seed every random number generator this script relies on with *seed*.

    Calls, in order: ``random.seed``, ``np.random.seed``,
    ``torch.manual_seed``, ``torch.cuda.manual_seed_all`` and
    ``transformers.set_seed``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def generate_batch(model, tokenizer, batch_samples, *, max_new_tokens=20,
                   max_input_length=1024, repetition_penalty=1.2):
    """Generate one answer string per sample in *batch_samples*.

    Each sample's passage and question are wrapped in an instruction-style
    prompt, the prompts are tokenized as one padded batch, and the model
    decodes greedily. Only the newly generated tokens are decoded.

    Args:
        model: Object exposing ``.device`` and ``.generate(...)``.
        tokenizer: Callable tokenizer exposing ``pad_token_id`` and
            ``batch_decode``.
        batch_samples: Mapping-like batch with ``'passage'`` and
            ``'question'`` columns of equal length.
        max_new_tokens: Maximum number of tokens generated per sample.
        max_input_length: Prompts are truncated to this many tokens.
        repetition_penalty: Repetition penalty passed to ``generate``.

    Returns:
        A list of stripped answer strings, one per sample, in input order.
    """
    # Named in lower case so it does not shadow the module-level PROMPT_TEMPLATE.
    prompt_template = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    )
    prompts = [
        prompt_template.format(
            instruction=f"Passage: {passage}\nQuestion: {question}"
        )
        for passage, question in zip(
            batch_samples['passage'], batch_samples['question']
        )
    ]

    # Tokenize the whole batch at once; padding makes every row the same length.
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding
            pad_token_id=tokenizer.pad_token_id,
            repetition_penalty=repetition_penalty,
        )

    # The output rows start with the (padded) prompt; slice it off so only the
    # continuation is decoded. All rows share one prompt length after padding.
    input_length = inputs.input_ids.shape[1]
    generated_tokens = outputs[:, input_length:]
    decoded_preds = tokenizer.batch_decode(
        generated_tokens, skip_special_tokens=True
    )
    return [text.strip() for text in decoded_preds]
@draccus.wrap()
def main(mainCfg: MainConfig):
    """Evaluate an adapter-tuned causal LM on DROP and save EM/F1 results.

    Loads the base model and tokenizer named in ``mainCfg.model``, merges the
    adapter stored under ``<adapter_path>/ft2``, runs batched greedy
    generation over the first ``mainCfg.data.total_test_samples`` rows of the
    ``validation`` split, scores each prediction with exact match and F1
    against all reference spans, prints the averages and writes per-sample
    details to ``<adapter_path>/drop_evaluation_results.json``.

    Raises:
        KeyError: If ``mainCfg.model.adapter_path`` is ``None``.
        ValueError: If no padding token can be chosen, or the pad token id
            cannot be synced to the model config.
    """
    print('='*120)
    set_seed(mainCfg.seed)

    # Validate the adapter path first so a misconfigured run fails before the
    # (expensive) base-model load rather than after it.
    adapter_path = mainCfg.model.adapter_path
    if adapter_path is None:
        raise KeyError('wrong adapter path: ', adapter_path)

    model = AutoModelForCausalLM.from_pretrained(
        mainCfg.model.model_name, device_map="auto", dtype=torch.float16
    )
    # Left padding keeps every prompt flush against its generated continuation.
    tokenizer = AutoTokenizer.from_pretrained(
        mainCfg.model.model_name, padding_side='left'
    )

    if tokenizer.pad_token is None:
        if tokenizer.unk_token_id is not None:
            tokenizer.pad_token_id = tokenizer.unk_token_id
            tokenizer.pad_token = tokenizer.unk_token
            print("Set PAD token to UNK token.")
        elif tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
            print("Set PAD token to EOS token.")
        else:
            # Batched tokenization with padding=True cannot work without one.
            raise ValueError(
                "Tokenizer has no PAD, UNK or EOS token to use for padding"
            )

    model.config.pad_token_id = tokenizer.pad_token_id
    if model.config.pad_token_id != tokenizer.pad_token_id:
        raise ValueError("Failed to sync pad_token_id between tokenizer and model config")

    # NOTE(review): is_trainable=True is kept from the original even though the
    # adapter is only merged for inference -- confirm whether smpeft needs it.
    model = PeftModel.from_pretrained(model, adapter_path+"/ft2", is_trainable = True)
    model = model.merge_and_unload()  # Merge for speed
    model.eval()

    full_drop_test = load_dataset(path=mainCfg.data.path, split='validation')
    # Clamp the requested sample count so select() never indexes past the split.
    num_samples = min(mainCfg.data.total_test_samples, len(full_drop_test))
    test_dataset_raw = full_drop_test.select(range(num_samples))
    num_eval = len(test_dataset_raw)

    results = []
    total_em = 0
    total_f1 = 0

    print(f"Starting Inference on {num_eval} samples...")
    # Lower-case local so the module-level BATCH_SIZE constant is not shadowed.
    batch_size = mainCfg.trainer_args.per_device_eval_batch_size
    for i in tqdm(range(0, num_eval, batch_size)):
        batch_indices = range(i, min(i + batch_size, num_eval))
        batch_samples = test_dataset_raw.select(batch_indices)

        batch_preds = generate_batch(model, tokenizer, batch_samples)

        for idx, pred in zip(batch_indices, batch_preds):
            original_item = test_dataset_raw[int(idx)]
            ground_truths = original_item['answers_spans']['spans']

            # --- GRADE ---
            em = metric_max_over_ground_truths(exact_match_score, pred, ground_truths)
            f1 = metric_max_over_ground_truths(f1_score, pred, ground_truths)
            total_em += em
            total_f1 += f1
            results.append({
                "id": original_item["query_id"],
                "prediction": pred,
                "ground_truths": ground_truths,
                "em": em,
                "f1": f1
            })

    # 4. Final Statistics (report 0.0 rather than dividing by zero when no
    # samples were evaluated).
    avg_em = 100.0 * total_em / num_eval if num_eval else 0.0
    avg_f1 = 100.0 * total_f1 / num_eval if num_eval else 0.0
    print("\n" + "="*30)
    print("RESULTS")
    print("="*30)
    print(f"Total Samples: {num_eval}")
    print(f"Exact Match (EM): {avg_em:.2f}%")
    print(f"F1 Score : {avg_f1:.2f}%")
    print("="*30)

    # 5. Save details to JSON
    output_file = adapter_path + "/drop_evaluation_results.json"
    with open(output_file, "w", encoding='utf-8') as f:
        json.dump({
            "metrics": {"EM": avg_em, "F1": avg_f1},
            "details": results  # per-sample records collected in `results`
        }, f, indent=2, ensure_ascii=False)
    print(f"Detailed results saved to {output_file}")
# Script entry point: main() is wrapped by draccus.wrap(), so it is called
# here without arguments.
if __name__ == "__main__":
    main()