|
|
|
|
|
"""Evaluate base model vs SCU adapter on BPT and perplexity.""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import math |
|
|
import json |
|
|
import random |
|
|
import statistics as stats |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).parent.parent)) |
|
|
from scu import data |
|
|
|
|
|
|
|
|
def bpt_for_texts(model, tokenizer, texts, max_len=512, device=None):
    """Compute bits per token (BPT) for each text independently.

    Each text is tokenized on its own (truncated to ``max_len`` tokens, no
    padding) and scored with the model's own loss, using the input ids as
    labels. That loss is in nats, so it is divided by ln(2) to get bits.

    Args:
        model: Causal LM (plain or adapter-wrapped) whose forward accepts
            ``labels=`` and returns an object with a scalar ``.loss``.
        tokenizer: Tokenizer matching ``model``; called with
            ``return_tensors="pt"``.
        texts: Iterable of strings to score.
        max_len: Maximum number of tokens kept per text.
        device: Device for the input tensors; defaults to ``model.device``.

    Returns:
        A list with exactly one BPT value per input text, in input order.
        A text that tokenizes to fewer than two tokens has no next-token
        prediction to score; its entry is ``float("nan")`` and the model is
        not called for it. Callers should drop non-finite values before
        averaging.
    """
    model.eval()
    target_device = device or model.device  # resolve once, not per text
    ln2 = math.log(2)
    bpts = []

    for text in texts:
        enc = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_len,
            padding=False,
        )

        # Causal-LM losses score token t+1 from tokens <= t, so a sequence
        # shorter than 2 tokens gives an empty (NaN) loss; record it
        # explicitly instead of spending a forward pass on it.
        if enc["input_ids"].shape[-1] < 2:
            bpts.append(float("nan"))
            continue

        enc = {k: v.to(target_device) for k, v in enc.items()}
        labels = enc["input_ids"].clone()

        with torch.no_grad():
            outputs = model(**enc, labels=labels)

        bpts.append(outputs.loss.item() / ln2)

    return bpts
|
|
|
|
|
|
|
|
def bootstrap_ci(delta_list, iters=10000, seed=42, confidence=0.95):
    """Percentile-bootstrap confidence interval for the mean of ``delta_list``.

    Resamples ``delta_list`` with replacement ``iters`` times, takes the mean
    of each resample, and returns the symmetric percentile interval of those
    means together with the plain sample mean.

    Args:
        delta_list: Non-empty sequence of numbers (e.g. paired per-text BPT
            differences).
        iters: Number of bootstrap resamples; must be >= 1.
        seed: Seed for a private RNG. Results are reproducible, and the
            module-level ``random`` state is left untouched.
        confidence: Two-sided coverage of the interval, in (0, 1).
            The default 0.95 gives a 95% interval.

    Returns:
        ``(lower, mean, upper)``, where ``mean`` is the mean of
        ``delta_list`` itself (not of the bootstrap means).

    Raises:
        ValueError: If ``delta_list`` is empty, ``iters`` < 1, or
            ``confidence`` is not strictly between 0 and 1.
    """
    n = len(delta_list)
    if n == 0:
        raise ValueError("bootstrap_ci requires at least one value")
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    if not 0 < confidence < 1:
        raise ValueError(f"confidence must be in (0, 1), got {confidence}")

    # A private generator yields the same draws as random.seed(seed) followed
    # by random.randrange, without resetting the global RNG for other code.
    rng = random.Random(seed)
    means = sorted(
        math.fsum(delta_list[rng.randrange(n)] for _ in range(n)) / n
        for _ in range(iters)
    )

    # Trim the same number of resample means from each tail so the interval
    # is symmetric. round() guards against float error such as
    # (1 - 0.9) / 2 * 100 == 4.999...; the clamp keeps lower <= upper.
    tail = int(round((1 - confidence) / 2 * iters, 6))
    tail = min(tail, (iters - 1) // 2)
    lower = means[tail]
    upper = means[iters - 1 - tail]

    return lower, stats.mean(delta_list), upper
|
|
|
|
|
|
|
|
def _select_device(no_4bit):
    """Choose ``(device, dtype, use_4bit)`` for evaluation.

    CUDA uses float16 and 4-bit loading unless ``no_4bit`` is set. MPS and
    CPU use float32 without quantization. Prints a warning when falling back
    to CPU.
    """
    if torch.cuda.is_available():
        return "cuda", torch.float16, not no_4bit
    if torch.backends.mps.is_available():
        return "mps", torch.float32, False
    print("WARNING: Using CPU - evaluation will be slow")
    return "cpu", torch.float32, False


def _load_base_model(model_name, device, dtype, use_4bit):
    """Load ``model_name`` as a causal LM, as 4-bit NF4 when requested on CUDA."""
    quantization_config = None
    if use_4bit and device == "cuda":
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the model repository; only point --base_model at repos you trust.
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        torch_dtype=dtype,
        device_map="auto" if device != "cpu" else None,
        trust_remote_code=True,
    )


def _print_banner(title):
    """Print ``title`` between two 60-char separator lines, after a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _report_skipped(num_total, num_scored):
    """Warn when some texts were excluded because their BPT was non-finite."""
    skipped = num_total - num_scored
    if skipped:
        print(f"WARNING: skipped {skipped} text(s) with non-finite BPT "
              "(e.g. fewer than 2 tokens)")


def _finite_or_none(values):
    """Map non-finite floats to None so the list serializes as strict JSON."""
    return [v if math.isfinite(v) else None for v in values]


def _save_results(path, results):
    """Write ``results`` as indented JSON to ``path``, creating parent dirs."""
    output_path = Path(path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {path}")


def main(args):
    """Evaluate the base model and, when ``args.adapter_path`` is set, the SCU adapter.

    Reports mean bits per token (BPT) and perplexity computed as
    ``2 ** mean_bpt``. Both are macro-averages over texts: each text counts
    equally, regardless of its token count. Texts with a non-finite BPT
    (e.g. too short to score) are excluded from all averages and reported
    as ``null`` in the per-text lists of the JSON output.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` block for flags).
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    device, dtype, use_4bit = _select_device(args.no_4bit)

    print(f"Loading base model: {args.base_model}")
    base_model = _load_base_model(args.base_model, device, dtype, use_4bit)

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading validation texts from {args.texts}")
    val_texts = data.load_texts_from_file(args.texts, max_texts=args.max_texts)
    print(f"Loaded {len(val_texts)} texts")
    if not val_texts:
        sys.exit(f"ERROR: no texts loaded from {args.texts}")

    print("\nEvaluating base model...")
    base_bpts = bpt_for_texts(
        base_model, tokenizer, val_texts, max_len=args.max_length, device=device
    )

    if not args.adapter_path:
        scored = [b for b in base_bpts if math.isfinite(b)]
        if not scored:
            sys.exit("ERROR: no text could be scored (all BPT values non-finite)")
        _report_skipped(len(val_texts), len(scored))
        base_mean_bpt = stats.mean(scored)
        base_perplexity = 2 ** base_mean_bpt

        _print_banner("BASE MODEL RESULTS")
        print(f"BPT: {base_mean_bpt:.3f}")
        print(f"Perplexity: {base_perplexity:.2f}")
        print(f"Texts: {len(val_texts)}")

        # --output is honored in base-only mode as well.
        if args.output:
            _save_results(args.output, {
                'base_model': args.base_model,
                'adapter_path': args.adapter_path,
                'num_texts': len(val_texts),
                'num_scored': len(scored),
                'base_bpt': base_mean_bpt,
                'base_perplexity': base_perplexity,
                'individual_base_bpts': _finite_or_none(base_bpts),
            })
        return

    print(f"\nLoading SCU adapter from {args.adapter_path}")
    # PEFT attaches the adapter to base_model's modules in place, so the
    # base-only evaluation above must happen before this call.
    scu_model = PeftModel.from_pretrained(base_model, args.adapter_path)
    scu_model.eval()

    print("Evaluating SCU model...")
    scu_bpts = bpt_for_texts(
        scu_model, tokenizer, val_texts, max_len=args.max_length, device=device
    )

    # Paired comparison: keep only texts both models scored finitely, so the
    # two means and the per-text deltas all cover the same texts.
    pairs = [
        (b, s) for b, s in zip(base_bpts, scu_bpts)
        if math.isfinite(b) and math.isfinite(s)
    ]
    if not pairs:
        sys.exit("ERROR: no text could be scored (all BPT values non-finite)")
    _report_skipped(len(val_texts), len(pairs))

    base_mean_bpt = stats.mean([b for b, _ in pairs])
    scu_mean_bpt = stats.mean([s for _, s in pairs])
    base_perplexity = 2 ** base_mean_bpt
    scu_perplexity = 2 ** scu_mean_bpt

    delta_bpts = [b - s for b, s in pairs]
    delta_mean = stats.mean(delta_bpts)
    delta_pct = 100 * delta_mean / base_mean_bpt
    ppl_reduction = 100 * (1 - scu_perplexity / base_perplexity)

    if args.bootstrap:
        print("\nCalculating bootstrap confidence interval...")
        ci_lower, ci_mean, ci_upper = bootstrap_ci(delta_bpts, iters=args.bootstrap_iters)
    else:
        ci_lower = ci_mean = ci_upper = delta_mean

    _print_banner("EVALUATION RESULTS")
    print(f"Base Model: {base_mean_bpt:.3f} BPT (ppl {base_perplexity:.2f})")
    print(f"SCU Model: {scu_mean_bpt:.3f} BPT (ppl {scu_perplexity:.2f})")
    print(f"Improvement: {delta_mean:.3f} BPT ({delta_pct:.1f}%)")
    # Signed change in perplexity: a reduction prints as "-x%" and an
    # increase as "+x%" (a hard-coded "-" prefix would give "--x%").
    print(f"Perplexity: {-ppl_reduction:+.1f}%")

    if args.bootstrap:
        print(f"\nBootstrap 95% CI: [{ci_lower:.3f}, {ci_upper:.3f}]")
        if ci_lower > 0:
            print("✓ CI excludes zero - improvement is statistically significant")
        else:
            print("✗ CI includes zero - improvement not statistically significant")

    if args.output:
        _save_results(args.output, {
            'base_model': args.base_model,
            'adapter_path': args.adapter_path,
            'num_texts': len(val_texts),
            'num_scored': len(pairs),
            'base_bpt': base_mean_bpt,
            'scu_bpt': scu_mean_bpt,
            'delta_bpt': delta_mean,
            'delta_bpt_percent': delta_pct,
            'base_perplexity': base_perplexity,
            'scu_perplexity': scu_perplexity,
            'perplexity_reduction': ppl_reduction,
            'ci_lower': ci_lower,
            'ci_mean': ci_mean,
            'ci_upper': ci_upper,
            # Per-text values stay aligned with the input texts; unscorable
            # entries become null rather than the non-standard token NaN.
            'individual_base_bpts': _finite_or_none(base_bpts),
            'individual_scu_bpts': _finite_or_none(scu_bpts),
        })
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Evaluate BPT and perplexity")

    # (flag, add_argument keyword options) pairs, registered in this order,
    # which is also the order they appear in --help.
    cli_options = (
        ("--base_model", {"default": "meta-llama/Llama-3.2-1B",
                          "help": "Base model name"}),
        ("--adapter_path", {"default": None,
                            "help": "Path to SCU adapter (optional)"}),
        ("--texts", {"default": "data/val.txt",
                     "help": "Validation texts file"}),
        ("--max_texts", {"type": int, "default": None,
                         "help": "Maximum texts to evaluate"}),
        ("--max_length", {"type": int, "default": 512,
                          "help": "Maximum sequence length"}),
        ("--no_4bit", {"action": "store_true",
                       "help": "Disable 4-bit quantization"}),
        ("--bootstrap", {"action": "store_true",
                         "help": "Calculate bootstrap CI"}),
        ("--bootstrap_iters", {"type": int, "default": 10000,
                               "help": "Bootstrap iterations"}),
        ("--output", {"default": None,
                      "help": "Output JSON file for results"}),
    )
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    main(cli.parse_args())