Spaces:
Sleeping
Sleeping
| """Model and metric utilities for ReACC-style generation.""" | |
| from __future__ import annotations | |
| from dataset import SPECIAL_TOKENS, build_prompt | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| from typing import Dict, Optional, Sequence | |
| from difflib import SequenceMatcher | |
| import math | |
| import os | |
| import sys | |
| CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| if CURRENT_DIR not in sys.path: | |
| sys.path.insert(0, CURRENT_DIR) | |
def load_model_and_tokenizer(model_name_or_path: str):
    """Load a causal LM and its tokenizer, registering the project's special tokens.

    Steps:
      1. Load the tokenizer. If it defines no pad token, reuse its EOS token.
      2. Register ``SPECIAL_TOKENS`` as additional special tokens.
      3. Load the model and resize its token embeddings to ``len(tokenizer)``
         so the newly added ids have embedding rows.

    Args:
        model_name_or_path: Hub id or local path accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name_or_path)
    # Some checkpoints define no pad token; fall back to EOS so a valid
    # pad id is available (generate_completion passes it as pad_token_id).
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})

    lm = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    # The vocabulary may have grown above; keep the embedding matrix in sync.
    vocab_size = len(tok)
    lm.resize_token_embeddings(vocab_size)
    return tok, lm
def generate_completion(
    model,
    tokenizer,
    retrieved: str,
    context: str,
    device: torch.device,
    max_length: int = 384,
    max_new_tokens: int = 64,
    do_sample: bool = False,
    temperature: float = 0.2,
    top_p: float = 0.95,
    stop_strings: Optional[Sequence[str]] = None,
) -> str:
    """Generate a completion for ``context`` conditioned on ``retrieved`` code.

    The prompt is built with ``build_prompt(retrieved, context)``, tokenized
    and passed to ``model.generate`` with a single beam. Only the newly
    generated tokens are decoded and returned.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        retrieved: Retrieved code, forwarded to ``build_prompt``.
        context: Code to be continued, forwarded to ``build_prompt``.
        device: Device the input tensors are moved to.
        max_length: Maximum prompt length in tokens. Longer prompts are
            truncated from the *left*, so the tokens nearest the completion
            point are kept.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample instead of greedy decoding. ``temperature`` and
            ``top_p`` are forwarded only when this is true.
        temperature: Sampling temperature (sampling only).
        top_p: Nucleus-sampling probability mass (sampling only).
        stop_strings: Optional strings. The output is cut at the earliest
            occurrence of any of them. Empty strings are ignored.

    Returns:
        The decoded continuation, without the prompt and without the EOS
        token (or anything after it). Other special tokens are kept.
    """
    prompt = build_prompt(retrieved, context)
    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep the *last* max_length tokens: generation continues from the end of
    # the prompt, so cutting from the right (the tokenizer default) would make
    # the model continue from the wrong position.
    inputs = {k: v[:, -max_length:].to(device) for k, v in encoded.items()}

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
    output = model.generate(**inputs, **generation_kwargs)

    # Decode only the new token ids. Slicing the decoded full text by the
    # length of the decoded prompt assumes decode() is prefix-consistent,
    # which is not guaranteed (whitespace clean-up, tokens merging across
    # the prompt/continuation boundary).
    prompt_len = inputs["input_ids"].shape[1]
    new_ids = output[0, prompt_len:].tolist()
    # Drop EOS (and anything after it) so the tokenizer's EOS text does not
    # leak into the completion and spoil exact-match / edit-similarity.
    eos_id = tokenizer.eos_token_id
    if eos_id is not None and eos_id in new_ids:
        new_ids = new_ids[: new_ids.index(eos_id)]
    generated = tokenizer.decode(new_ids, skip_special_tokens=False)

    if stop_strings:
        # Empty stop strings are skipped: "".find would match at 0 and
        # erase the whole completion.
        hits = [generated.find(s) for s in stop_strings if s]
        cut_points = [pos for pos in hits if pos >= 0]
        if cut_points:
            generated = generated[: min(cut_points)]
    return generated
def exact_match(pred: str, gold: str) -> float:
    """Score 1.0 if *pred* equals *gold* once surrounding whitespace is trimmed, else 0.0."""
    identical = pred.strip() == gold.strip()
    return float(identical)
def edit_similarity(pred: str, gold: str) -> float:
    """Return the character-level similarity of *pred* and *gold* as a percentage.

    Both strings are stripped of surrounding whitespace and compared with
    ``difflib.SequenceMatcher.ratio()`` (``2 * matches / total_length``),
    scaled to [0, 100]. Two empty strings score 100.0.
    """
    # autojunk=False: with the default heuristic, any character that makes
    # up more than 1% of a string of 200+ characters is treated as "popular"
    # and cannot seed a match, which can grossly underestimate the similarity
    # of longer code snippets (near-identical strings may even score 0).
    matcher = SequenceMatcher(None, pred.strip(), gold.strip(), autojunk=False)
    return matcher.ratio() * 100.0
def perplexity_from_loss(loss_value: float, max_loss: float = 20.0) -> float:
    """Convert a natural-log cross-entropy loss to perplexity, ``exp(loss)``.

    Args:
        loss_value: Loss to convert.
        max_loss: Losses at or above this value are reported as ``inf``
            instead of a huge finite number. The default of 20.0 preserves
            the original behavior.

    Returns:
        ``math.exp(loss_value)``, or ``inf`` when the cap is hit or the
        exponential overflows (possible only with a cap above ~709.78).
    """
    if loss_value >= max_loss:
        return float("inf")
    try:
        return math.exp(loss_value)
    except OverflowError:
        return float("inf")
def evaluate_generation(preds: Sequence[str], golds: Sequence[str]) -> Dict[str, float]:
    """Average exact-match and edit-similarity scores over paired predictions.

    Args:
        preds: Model predictions.
        golds: Reference completions, aligned index-by-index with ``preds``.

    Returns:
        ``{"exact_match": ..., "edit_similarity": ...}``, both percentages in
        [0, 100]. Empty input yields 0.0 for both.

    Raises:
        ValueError: If ``preds`` and ``golds`` differ in length.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, after which zip() would silently drop the unmatched tail
    # while the averages are still divided by len(preds).
    if len(preds) != len(golds):
        raise ValueError(
            f"preds and golds must have the same length, "
            f"got {len(preds)} and {len(golds)}"
        )
    count = len(preds)
    if count == 0:
        return {"exact_match": 0.0, "edit_similarity": 0.0}
    pairs = list(zip(preds, golds))
    em = sum(exact_match(p, g) for p, g in pairs) / count
    es = sum(edit_similarity(p, g) for p, g in pairs) / count
    return {"exact_match": em * 100.0, "edit_similarity": es}