from transformers import LlamaForCausalLM, GPT2LMHeadModel, AutoTokenizer
import torch
import math
from collections import Counter
import argparse
from torch.cuda.amp import autocast
import torch.nn.functional as F
import json
import matplotlib.pyplot as plt
import os

from data import get_text_dataset
from data.text import setup_tokeniser_from_dataset

llama_model_path = "meta-llama/Llama-2-7b-hf"
gpt2_model_path = "gpt2-large"


def get_reference_text_dataset():
    dataset = get_text_dataset(
        "openwebtext",
        split="train",
        max_length=1024,
        filter_max_length=False,
    )[:5000]["input_ids"]
    tokeniser = setup_tokeniser_from_dataset("openwebtext")
    return tokeniser.batch_decode(dataset, skip_special_tokens=True)
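
# Note: `get_text_dataset` and `setup_tokeniser_from_dataset` come from this
# repository's local `data` package; the function above decodes the first
# 5,000 OpenWebText training sequences back into text for use as references.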


def batch_reduce(batch, func, reduce_fn, init, step=16):
    """
    Apply `func` to sub-batches of `batch` of size at most `step`, and fold
    the per-sub-batch results together with `reduce_fn`, starting from `init`
    (a left fold over the sub-batch results).
    """
    result = init
    for i in range(0, len(batch), step):
        sub_batch = batch[i : min(i + step, len(batch))]
        sub_result = func(sub_batch)
        result = reduce_fn(result, sub_result)
    return result
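
# Illustrative usage (hypothetical values): square a list in sub-batches of
# four and concatenate the per-batch results.
#   batch_reduce(list(range(10)),
#                func=lambda xs: [x * x for x in xs],
#                reduce_fn=lambda acc, res: acc + res,
#                init=[], step=4)
#   -> [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]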


def compute_generative_perplexity(
    text_samples, max_length: int = 1024,
    input_is_tokenized: bool = False, tokenizer=None, model_type="llama"
) -> list:
    # Load the evaluation model specified by model_type
    if model_type == "llama":
        eval_model = LlamaForCausalLM.from_pretrained(
            llama_model_path,
            torch_dtype=torch.float16,
        ).eval()
        model_path = llama_model_path
    elif model_type == "gpt2-large":
        eval_model = GPT2LMHeadModel.from_pretrained(
            gpt2_model_path,
            torch_dtype=torch.float16,
        ).eval()
        model_path = gpt2_model_path
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    eval_model = eval_model.to("cuda")

    # Use the caller-supplied tokenizer if given; otherwise load one to match
    # the evaluation model
    if tokenizer is None:
        eval_model_tokenizer = AutoTokenizer.from_pretrained(model_path)
        eval_model_tokenizer.pad_token = eval_model_tokenizer.eos_token
    else:
        eval_model_tokenizer = tokenizer

    # Tokenize the batch, or use the pre-tokenized input directly
    if input_is_tokenized:
        # Input is already token IDs: build a right-padded tensor
        max_len = max(len(seq) for seq in text_samples)
        padded_max_len = min(max_len, max_length)
        input_ids = torch.ones((len(text_samples), padded_max_len),
                               dtype=torch.long) * eval_model_tokenizer.pad_token_id
        for i, seq in enumerate(text_samples):
            seq_len = min(len(seq), padded_max_len)
            input_ids[i, :seq_len] = torch.tensor(seq[:seq_len])
        input_ids = input_ids.to(eval_model.device)
    else:
        # Tokenize the text samples
        tokenized = eval_model_tokenizer(
            text_samples,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        ).to(eval_model.device)
        input_ids = tokenized["input_ids"]

    # The pad token equals the EOS token, so everything after the first EOS is
    # padding. `first_eos` is True wherever the running EOS count is exactly
    # one; of the EOS positions, only the first one qualifies.
    eos_token_id = eval_model_tokenizer.eos_token_id
    eos_mask = input_ids == eos_token_id
    first_eos = eos_mask.cumsum(dim=-1) == 1

    # Generative perplexity: per-token NLL of each next token under the model
    with autocast(), torch.no_grad():
        outputs = eval_model(input_ids)
        logits = outputs.logits
        # cross_entropy expects class logits in dim 1, so reshape to [B, V, N]
        # (V = vocabulary size)
        logits = logits.transpose(-1, -2)
        nlls = F.cross_entropy(logits[..., :-1], input_ids[..., 1:], reduction="none")
        # Keep every non-EOS target plus the first EOS; mask out EOS padding
        effective_mask = (first_eos[..., 1:] + (input_ids[..., 1:] != eos_token_id)).bool()
        nlls = nlls * effective_mask

    # Per-sample perplexity: exp of the mean NLL over unmasked positions
    likelihood_list = []
    for b in range(input_ids.size(0)):
        nll = nlls[b]
        mask = effective_mask[b]
        likelihood = nll.sum() / mask.sum()
        likelihood_list.append(likelihood.exp().item())
    return likelihood_list
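
# Illustrative call (assumes a CUDA device and access to the LLaMA-2 weights;
# the sample text is hypothetical):
#   perps = compute_generative_perplexity(["The quick brown fox."], model_type="llama")
#   -> a list with one perplexity value per input sample.
# Note that the model is loaded inside the function, so every call (e.g. each
# sub-batch dispatched by batch_reduce) reloads it from scratch.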


def compute_entropy(samples: list, model_name: str = llama_model_path,
                    input_is_tokenized: bool = False, tokenizer=None):
    """
    Compute the unigram entropy (in bits) of each sample over its subword tokens.
    Accepts either text samples or pre-tokenized token-ID sequences.
    """
    # use provided token IDs, or encode each text sample into subword IDs
    if input_is_tokenized:
        token_id_seqs = samples
    else:
        # a tokenizer is only needed when the input is raw text
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        # encode each sample into subword IDs (no special tokens)
        token_id_seqs = [
            tokenizer.encode(sample, add_special_tokens=False) for sample in samples
        ]
    # compute per-sample entropy from the empirical token distribution
    entropies = []
    for seq in token_id_seqs:
        counts = Counter(seq)
        total = sum(counts.values())
        entropy = (
            -sum((cnt / total) * math.log(cnt / total, 2) for cnt in counts.values())
            if total > 0
            else 0.0
        )
        entropies.append(entropy)
    return entropies
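
# Worked example (illustrative): four distinct tokens, each appearing once,
# give a uniform unigram distribution, so the entropy is
# -4 * (1/4) * log2(1/4) = 2.0 bits:
#   compute_entropy([[5, 17, 42, 99]], input_is_tokenized=True)  # -> [2.0]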


def compute_mauve_score(candidate_samples, reference_samples):
    import mauve

    score = mauve.compute_mauve(
        p_text=candidate_samples,
        q_text=reference_samples,
        device_id=0,
        max_text_length=1024,
        verbose=False,
    )
    return score.mauve
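
# Illustrative usage: MAUVE compares the distribution of candidate texts with
# that of reference texts and returns a scalar between 0 and 1 (higher means
# the two distributions are closer):
#   compute_mauve_score(generated_texts, reference_texts)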


def main():
    parser = argparse.ArgumentParser(
        description="Compute average entropy, generative perplexity, and MAUVE score for a list of text samples."
    )
    parser.add_argument(
        "--input-json",
        type=str,
        help="Path to a JSON file containing a list of strings",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        help="Batch size for computing generative perplexity",
    )
    parser.add_argument(
        "--length-plot-output",
        type=str,
        default="length_distribution.png",
        help="Output path for the sentence length distribution plot",
    )
    parser.add_argument(
        "--perplexity-plot-output",
        type=str,
        default=None,  # derived from --length-plot-output when unset
        help="Output path for the perplexity vs. length scatter plot",
    )
    parser.add_argument(
        "--results-output",
        type=str,
        default=None,
        help="Path to a JSON file in which to save the computed metrics",
    )
    parser.add_argument(
        "--eval-mode",
        type=str,
        choices=["sentence", "chunk"],
        default="sentence",
        help="sentence: evaluate each input as one sample; chunk: tokenize and split into 1024-token segments",
    )
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["llama", "gpt2-large"],
        default="llama",
        help="Model to use for generative perplexity evaluation",
    )
    # Flags to enable each metric (all off by default)
    parser.add_argument("--entropy", action="store_true", default=False, help="Evaluate entropy")
    parser.add_argument("--perplexity", action="store_true", default=False, help="Evaluate generative perplexity")
    parser.add_argument("--mauve", action="store_true", default=False, help="Evaluate MAUVE score")
    parser.add_argument("--reference-perplexity", action="store_true", default=False, help="Evaluate reference text perplexity")
    args = parser.parse_args()

    # Derive the perplexity plot path from the length plot path if not specified
    if args.perplexity_plot_output is None:
        base, ext = os.path.splitext(args.length_plot_output)
        args.perplexity_plot_output = f"{base}_perplexity{ext}"

    with open(args.input_json, "r", encoding="utf-8") as f:
        samples = json.load(f)

    # Choose sentence-level or chunk-level inputs
    if args.eval_mode == "chunk":
        # Pre-load the tokenizer that matches the evaluation model
        model_path = llama_model_path if args.model_type == "llama" else gpt2_model_path
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        tokenizer.pad_token = tokenizer.eos_token
        chunk_size = 1024
        # Tokenize all samples
        token_id_seqs = [tokenizer.encode(s, add_special_tokens=False) for s in samples]
        # Concatenate all sentences, with an EOS token between them
        concatenated_tokens = []
        for seq in token_id_seqs:
            concatenated_tokens.extend(seq)
            concatenated_tokens.append(tokenizer.eos_token_id)
        # Truncate the token stream to a multiple of chunk_size
        truncated_length = (len(concatenated_tokens) // chunk_size) * chunk_size
        concatenated_tokens = concatenated_tokens[:truncated_length]
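        # e.g. a stream of 2,600 concatenated tokens yields 2 full chunks
        # (2,048 tokens); the trailing 552 tokens are dropped.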
        # Split the stream into chunks of chunk_size token IDs
        chunks = []
        for i in range(0, len(concatenated_tokens), chunk_size):
            chunks.append(concatenated_tokens[i:i + chunk_size])
        # Keep the chunks as token IDs for direct use downstream
        target_samples = chunks
        use_tokenized_input = True
    else:
        target_samples = samples
        use_tokenized_input = False

    # Conditionally compute entropy
    if args.entropy:
        entropy_list = compute_entropy(
            target_samples,
            input_is_tokenized=use_tokenized_input,
            tokenizer=tokenizer if use_tokenized_input else None,
        )
        avg_entropy = sum(entropy_list) / len(entropy_list)
        print(f"Average entropy: {avg_entropy:.4f}")
    else:
        avg_entropy = None
        print("Entropy evaluation skipped")

    # Conditionally compute generative perplexity
    if args.perplexity:
        all_perps = batch_reduce(
            target_samples,
            lambda batch: compute_generative_perplexity(
                batch,
                input_is_tokenized=use_tokenized_input,
                tokenizer=tokenizer if use_tokenized_input else None,
                model_type=args.model_type,
            ),
            lambda acc, res: acc + res,
            init=[],
            step=args.batch_size,
        )
        avg_perp = sum(all_perps) / len(all_perps)
        print(f"Average generative perplexity: {avg_perp:.4f}")
    else:
        avg_perp = None
        all_perps = None
        print("Generative perplexity evaluation skipped")

    # Conditionally compute reference text perplexity
    if args.reference_perplexity:
        print("Computing reference text perplexity...")
        reference_samples = get_reference_text_dataset()
        reference_perps = batch_reduce(
            reference_samples,
            lambda batch: compute_generative_perplexity(
                batch,
                input_is_tokenized=False,
                tokenizer=None,
                model_type=args.model_type,
            ),
            lambda acc, res: acc + res,
            init=[],
            step=args.batch_size,
        )
        avg_reference_perp = sum(reference_perps) / len(reference_perps)
        print(f"Average reference perplexity: {avg_reference_perp:.4f}")
    else:
        avg_reference_perp = None
        reference_perps = None
        reference_samples = None
        print("Reference perplexity evaluation skipped")

    # Conditionally compute the MAUVE score
    if args.mauve:
        if reference_samples is None:
            reference_samples = get_reference_text_dataset()
        mauve_score = compute_mauve_score(samples, reference_samples)
        print(f"Mauve score: {mauve_score:.4f}")
    else:
        mauve_score = None
        print("Mauve evaluation skipped")

    # Calculate tokenized lengths early, for the filtered perplexities and
    # the length-distribution plot below
    gpt2_tokenizer = AutoTokenizer.from_pretrained(gpt2_model_path)
    lengths = [len(gpt2_tokenizer.encode(s, add_special_tokens=False)) for s in samples]

    # Conditionally create the perplexity vs. tokenized-length plot when
    # perplexity was evaluated in sentence mode
    filtered_perplexities = None
    reference_filtered_perplexities = None
    if args.perplexity and args.eval_mode == "sentence":
        # For each threshold i, average the perplexity over samples that are
        # at least i tokens long
        idx = []
        val = []
        for i in range(0, 1024):
            _val = []
            for l, perp in zip(lengths, all_perps):
                if l >= i:
                    _val.append(perp)
            idx.append(i)
            val.append(sum(_val) / len(_val) if _val else 0)
        # Store the filtered perplexities for the JSON output
        filtered_perplexities = {
            "token_thresholds": idx,
            "avg_perplexities": val,
        }
        plt.figure(figsize=(12, 6))
        # Plot candidate samples
        plt.scatter(idx, val, alpha=0.6, color="blue", label="Candidate samples")
        # Plot reference samples if available
        if args.reference_perplexity and reference_samples is not None:
            reference_lengths = [len(gpt2_tokenizer.encode(s, add_special_tokens=False)) for s in reference_samples]
            ref_idx = []
            ref_val = []
            for i in range(0, 1024):
                _ref_val = []
                for l, perp in zip(reference_lengths, reference_perps):
                    if l >= i:
                        _ref_val.append(perp)
                ref_idx.append(i)
                ref_val.append(sum(_ref_val) / len(_ref_val) if _ref_val else 0)
            # Store the reference filtered perplexities for the JSON output
            reference_filtered_perplexities = {
                "token_thresholds": ref_idx,
                "avg_perplexities": ref_val,
            }
            plt.scatter(ref_idx, ref_val, alpha=0.6, color="red", label="Reference samples")
        # Add horizontal reference lines at specific token-length thresholds
        for tlen in [10, 20, 30, 40, 50, 75, 100]:
            if tlen < len(val):
                plt.axhline(y=val[tlen], linestyle='--', color='blue', alpha=0.3)
| plt.title("Perplexity vs. Tokenized Length") | |
| plt.xlabel("Number of tokens") | |
| plt.ylabel("Log Perplexity") | |
| plt.legend() | |
| ax = plt.gca() | |
| ticks = list(ax.get_yticks()) | |
| for tlen in [10, 20, 30, 40, 50, 75, 100]: | |
| if tlen < len(val): | |
| tick_value = val[tlen] | |
| if tick_value not in ticks: | |
| ticks.append(tick_value) | |
| ax.set_yticks(sorted(ticks)) | |
| import matplotlib.ticker as ticker | |
| ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f')) | |
| plt.yscale("log") | |
| plt.grid(True, linestyle='--', alpha=0.7) | |
| plt.tight_layout() | |
| plt.savefig(args.perplexity_plot_output) | |
| print(f"Saved perplexity vs. length scatter plot to {args.perplexity_plot_output}") | |
    elif args.eval_mode == "sentence":
        print("Perplexity plot skipped because the --perplexity flag was not provided")

    if args.results_output:
        results = {
            "avg_entropy": avg_entropy,
            "avg_perplexity": avg_perp,
            "avg_reference_perplexity": avg_reference_perp,
            "mauve_score": mauve_score,
            "filtered_perplexities": filtered_perplexities,
            "reference_filtered_perplexities": reference_filtered_perplexities,
        }
        with open(args.results_output, "w", encoding="utf-8") as outf:
            json.dump(results, outf, indent=2)
        print(f"Saved metrics to {args.results_output}")

    # Plot the cumulative distribution of GPT-2-tokenized sentence lengths
    sorted_lengths = sorted(lengths)
    cumulative_percentages = [i / len(sorted_lengths) * 100 for i in range(1, len(sorted_lengths) + 1)]

    # Save the length data to a JSON file alongside the plot
    length_data = {
        "lengths": lengths,
        "sorted_lengths": sorted_lengths,
        "cumulative_percentages": cumulative_percentages,
        "num_samples": len(samples),
    }
    base, ext = os.path.splitext(args.length_plot_output)
    length_data_output = f"{base}.json"
    with open(length_data_output, "w", encoding="utf-8") as f:
        json.dump(length_data, f, indent=2)
    print(f"Saved length distribution data to {length_data_output}")

    plt.figure()
    plt.plot(sorted_lengths, cumulative_percentages, color="skyblue", linewidth=2)
    plt.title("Tokenized Sentence Length Cumulative Distribution")
    plt.xlabel("Number of tokens")
    plt.ylabel("Cumulative percentage (%)")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(args.length_plot_output)

    if args.eval_mode == "chunk":
        print(f"Evaluated in chunk mode over {len(target_samples)} segments (using pre-tokenized input)")
    else:
        print(f"Evaluated in sentence mode over {len(target_samples)} samples")
    if args.perplexity:
        print(f"Saved perplexity vs. length scatter plot to {args.perplexity_plot_output}")


if __name__ == "__main__":
    main()
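
# Example invocation (illustrative; substitute this file's actual name, and
# note that "samples.json" is a hypothetical file holding a JSON list of strings):
#   python evaluate.py --input-json samples.json --entropy --perplexity \
#       --model-type gpt2-large --batch-size 8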