Spaces:
Paused
Paused
| from __future__ import annotations | |
| import argparse | |
| from dataclasses import replace | |
| from datetime import datetime, timezone | |
| import json | |
| from pathlib import Path | |
| import sys | |
| from typing import Any | |
| import re | |
| import torch | |
# Repository root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Put the root on sys.path (once) so the `src.*` imports that follow resolve
# when this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
| from src.model.config import TOY_CONFIG | |
| from src.model.qwen_anchor_overlay import QwenAnchorOverlay | |
# Prompt used for both generations when --prompt is not supplied.
DEFAULT_PROMPT = "You are a vegan chef. Write a detailed weekly meal plan with recipes for each day."
# Keywords that count in favour of a continuation (default for
# --positive_keywords). Order is significant: it fixes the key order of the
# hit dictionaries and breaks ties between keywords found at the same offset.
DEFAULT_POSITIVE_KEYWORDS = [
    "vegan", "plant-based",
    "tofu",
    "lentil", "lentils",
    "chickpea", "chickpeas",
    "bean", "beans",
    "vegetable", "vegetables",
    "mushroom", "mushrooms",
    "dairy-free",
]
# Keywords that count against a continuation (default for
# --negative_keywords). Order is significant: it fixes the key order of the
# hit dictionaries and breaks ties between keywords found at the same offset.
DEFAULT_NEGATIVE_KEYWORDS = [
    "meat", "chicken", "beef", "pork", "bacon",
    "fish", "salmon", "tuna", "shrimp",
    "egg", "eggs",
    "cheese", "butter", "milk", "cream", "yogurt",
    "sausage", "ham",
]
def _split_csv(value: str) -> list[str]:
    """Split a comma-separated string into trimmed, non-empty items.

    Items keep their original order; pieces that are empty or contain only
    whitespace are dropped.
    """
    trimmed = (piece.strip() for piece in value.split(","))
    return [piece for piece in trimmed if piece]
def generate_base(
    overlay: QwenAnchorOverlay,
    prompt: str,
    max_new_tokens: int,
    max_length: int,
) -> dict[str, Any]:
    """Greedily decode a continuation of ``prompt`` with ``overlay.base_model``.

    The prompt is tokenized as a batch of one (truncated to ``max_length``
    tokens) and extended one token at a time by taking the argmax of the
    last-position logits. The whole sequence is fed to the model again at
    every step; no key/value cache is passed.

    Args:
        overlay: Object exposing ``tokenizer``, ``parameters()`` and
            ``base_model``.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.
        max_length: Upper bound on the total sequence length (prompt plus
            continuation); also the tokenizer truncation length.

    Returns:
        A dict with ``prompt``, ``generated_text`` (prompt + continuation),
        ``continuation_text`` (new tokens only, special tokens skipped) and
        ``steps`` (one ``{"token_id", "token_text"}`` entry per new token).

    Raises:
        ValueError: If ``overlay.tokenizer`` is ``None``.
    """
    tokenizer = overlay.tokenizer
    if tokenizer is None:
        raise ValueError("tokenizer is required")
    device = next(overlay.parameters()).device

    encoded = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    prompt_length = input_ids.size(1)

    # The attribute can be absent *or* present-but-None; int(None) would
    # raise, so resolve it once here and treat None as "no EOS token".
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    if eos_token_id is not None:
        eos_token_id = int(eos_token_id)

    # Cap the number of steps up front so the sequence never grows past
    # max_length, including when the prompt alone already fills it.
    step_budget = max(0, min(max_new_tokens, max_length - prompt_length))

    generated = input_ids
    generated_mask = attention_mask
    steps: list[dict[str, Any]] = []
    # Pure inference: no gradients are needed, so avoid autograd bookkeeping.
    with torch.no_grad():
        for _ in range(step_budget):
            outputs = overlay.base_model(
                input_ids=generated,
                attention_mask=generated_mask,
                return_dict=True,
            )
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            generated_mask = torch.cat(
                [generated_mask, torch.ones((1, 1), device=device, dtype=generated_mask.dtype)],
                dim=1,
            )
            token_id = int(next_token.item())
            steps.append(
                {
                    "token_id": token_id,
                    "token_text": tokenizer.decode([token_id], skip_special_tokens=False),
                }
            )
            # The EOS token itself is recorded in `steps` before stopping.
            if eos_token_id is not None and token_id == eos_token_id:
                break

    continuation_ids = generated[0, prompt_length:]
    continuation_text = tokenizer.decode(continuation_ids, skip_special_tokens=True)
    return {
        "prompt": prompt,
        "generated_text": tokenizer.decode(generated[0], skip_special_tokens=True),
        "continuation_text": continuation_text,
        "steps": steps,
    }
def _keyword_offsets(lowered: str, token: str) -> list[int]:
    """Return the start offset of every non-overlapping occurrence of ``token``."""
    if not token:
        # str.find("") matches at every position without advancing, so an
        # empty keyword would never terminate the scan; it has no hits.
        return []
    offsets: list[int] = []
    idx = lowered.find(token)
    while idx >= 0:
        offsets.append(idx)
        idx = lowered.find(token, idx + len(token))
    return offsets


def _is_protected(lowered: str, idx: int) -> bool:
    """Tell whether the 12 characters before ``idx`` contain a vegan qualifier."""
    prefix = lowered[max(0, idx - 12) : idx]
    return "vegan " in prefix or "plant-based " in prefix


def analyze_keywords(
    text: str,
    positive_keywords: list[str],
    negative_keywords: list[str],
) -> dict[str, Any]:
    """Score ``text`` by keyword hits and by how repetitive it is.

    Matching is done on ``text.lower()`` with plain substring search, so a
    keyword also matches inside longer words (``"egg"`` matches
    ``"eggplant"``) and keywords are expected to be lowercase already.
    Empty keywords are ignored.

    A negative keyword occurrence is "protected" when ``"vegan "`` or
    ``"plant-based "`` appears in the 12 characters before it (for example
    "vegan cheese"). Protected occurrences are reported separately and are
    excluded from ``negative_hits``, ``negative_total``, the scores and
    ``first_negative``.

    Args:
        text: Text to analyse.
        positive_keywords: Keywords that raise the lexical score.
        negative_keywords: Keywords that lower the lexical score.

    Returns:
        A dict with:
        ``positive_hits`` / ``negative_hits`` / ``protected_negative_hits``
        (keyword -> count, only keywords with a non-zero count),
        ``positive_total`` / ``negative_total``,
        ``first_positive`` / ``first_negative`` (``{"token", "char_index"}``
        of the earliest counted occurrence, or ``None``),
        ``lexical_score`` (positive total minus negative total),
        ``repeated_bigram_ratio`` (share of word bigrams that are repeats),
        ``max_sentence_repeat`` (highest count of any identical sentence),
        ``degeneracy_penalty`` and ``quality_score``.
    """
    lowered = text.lower()

    # --- positive keywords -------------------------------------------------
    positive_hits: dict[str, int] = {}
    first_positive = None
    for token in positive_keywords:
        offsets = _keyword_offsets(lowered, token)
        if not offsets:
            continue
        positive_hits[token] = len(offsets)
        # Strict "<" keeps the earlier-listed keyword on equal offsets.
        if first_positive is None or offsets[0] < first_positive["char_index"]:
            first_positive = {"token": token, "char_index": offsets[0]}

    # --- negative keywords (single scan per keyword) -----------------------
    negative_hits: dict[str, int] = {}
    protected_negative_hits: dict[str, int] = {}
    first_negative = None
    for token in negative_keywords:
        offsets = _keyword_offsets(lowered, token)
        if not offsets:
            continue
        unprotected = [idx for idx in offsets if not _is_protected(lowered, idx)]
        protected = len(offsets) - len(unprotected)
        if unprotected:
            negative_hits[token] = len(unprotected)
            # Only occurrences that actually count as negative are eligible,
            # so this stays consistent with negative_hits.
            if first_negative is None or unprotected[0] < first_negative["char_index"]:
                first_negative = {"token": token, "char_index": unprotected[0]}
        if protected > 0:
            protected_negative_hits[token] = protected

    positive_total = sum(positive_hits.values())
    negative_total = sum(negative_hits.values())
    lexical_score = float(positive_total - negative_total)

    # --- repetition metrics ------------------------------------------------
    word_tokens = re.findall(r"[a-zA-Z_]+", lowered)
    bigrams = list(zip(word_tokens, word_tokens[1:]))
    repeated_bigram_ratio = 0.0
    if bigrams:
        repeated_bigram_ratio = 1.0 - (len(set(bigrams)) / len(bigrams))

    sentence_counts: dict[str, int] = {}
    for raw_sentence in re.split(r"[.!?\n]+", lowered):
        sentence = raw_sentence.strip()
        if sentence:
            sentence_counts[sentence] = sentence_counts.get(sentence, 0) + 1
    max_sentence_repeat = max(sentence_counts.values(), default=0)

    degeneracy_penalty = float(8.0 * repeated_bigram_ratio + 1.5 * max(0, max_sentence_repeat - 1))
    # Negative hits are already subtracted once in lexical_score; the extra
    # 1.5x term here is an additional penalty on top of that.
    quality_score = float(lexical_score - degeneracy_penalty - 1.5 * negative_total)

    return {
        "positive_hits": positive_hits,
        "negative_hits": negative_hits,
        "protected_negative_hits": protected_negative_hits,
        "positive_total": int(positive_total),
        "negative_total": int(negative_total),
        "first_positive": first_positive,
        "first_negative": first_negative,
        "lexical_score": lexical_score,
        "repeated_bigram_ratio": float(repeated_bigram_ratio),
        "max_sentence_repeat": int(max_sentence_repeat),
        "degeneracy_penalty": float(degeneracy_penalty),
        "quality_score": float(quality_score),
    }
def build_markdown_report(
    *,
    model_name: str,
    device: str,
    prompt: str,
    max_new_tokens: int,
    max_length: int,
    conflict_threshold: float,
    bias_scale: float,
    repetition_penalty: float,
    frequency_penalty: float,
    no_repeat_ngram_size: int,
    max_bias_gate_sum: float,
    entropy_top_k: int,
    entropy_threshold: float,
    entropy_slope: float,
    pressure_threshold: float,
    pressure_slope: float,
    pressure_rescue_floor: float,
    base: dict[str, Any],
    anchor: dict[str, Any],
    base_analysis: dict[str, Any],
    anchor_analysis: dict[str, Any],
) -> str:
    """Render the base vs anchor-biased comparison as a markdown document.

    Args:
        model_name, device, max_new_tokens, max_length and the bias/entropy/
            pressure settings: run parameters, echoed in the report header.
        prompt: Prompt text, rendered as a blockquote.
        base: Base generation result; ``continuation_text`` is read.
        anchor: Anchor-biased generation result; ``continuation_text`` and
            ``steps`` are read. A step counts as bias-active when its
            ``bias_nonzero_anchors`` value (default 0) is greater than zero.
        base_analysis: Keyword analysis of the base continuation.
        anchor_analysis: Keyword analysis of the anchor continuation.

    Returns:
        The report text, terminated by a newline. The date stamp is the
        current time in UTC.
    """
    active_bias_steps = sum(
        1 for step in anchor["steps"] if step.get("bias_nonzero_anchors", 0) > 0
    )
    # The timestamp is taken in UTC, so label it with the literal "UTC".
    date_stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
    # Quote every prompt line so a multi-line prompt stays inside the
    # blockquote; an empty prompt still produces one (empty) quoted line.
    quoted_prompt = [f"> {line}" for line in (prompt.splitlines() or [""])]
    lines = [
        "# Qwen Long Retention Compare",
        "",
        f"Date: {date_stamp}",
        f"Model: `{model_name}`",
        f"Device: `{device}`",
        f"Max new tokens: `{max_new_tokens}`",
        f"Max length: `{max_length}`",
        f"Conflict threshold: `{conflict_threshold:.2f}`",
        f"Bias scale: `{bias_scale:.2f}`",
        f"Repetition penalty: `{repetition_penalty:.2f}`",
        f"Frequency penalty: `{frequency_penalty:.2f}`",
        f"No-repeat ngram size: `{no_repeat_ngram_size}`",
        f"Max bias gate sum: `{max_bias_gate_sum:.2f}`",
        f"Entropy top-k: `{entropy_top_k}`",
        f"Entropy threshold: `{entropy_threshold:.2f}`",
        f"Entropy slope: `{entropy_slope:.2f}`",
        f"Pressure threshold: `{pressure_threshold:.2f}`",
        f"Pressure slope: `{pressure_slope:.2f}`",
        f"Pressure rescue floor: `{pressure_rescue_floor:.2f}`",
        "",
        "## Prompt",
        "",
        *quoted_prompt,
        "",
        "## Summary",
        "",
        f"- Base lexical score: `{base_analysis['lexical_score']:.2f}`",
        f"- Anchor lexical score: `{anchor_analysis['lexical_score']:.2f}`",
        f"- Base quality score: `{base_analysis['quality_score']:.2f}`",
        f"- Anchor quality score: `{anchor_analysis['quality_score']:.2f}`",
        f"- Base positive hits: `{base_analysis['positive_total']}`",
        f"- Anchor positive hits: `{anchor_analysis['positive_total']}`",
        f"- Base negative hits: `{base_analysis['negative_total']}`",
        f"- Anchor negative hits: `{anchor_analysis['negative_total']}`",
        f"- Anchor protected negative hits: `{sum(anchor_analysis['protected_negative_hits'].values())}`",
        f"- Base degeneracy penalty: `{base_analysis['degeneracy_penalty']:.2f}`",
        f"- Anchor degeneracy penalty: `{anchor_analysis['degeneracy_penalty']:.2f}`",
        f"- Anchor bias active steps: `{active_bias_steps}`",
        f"- Continuations identical: `{'yes' if base['continuation_text'] == anchor['continuation_text'] else 'no'}`",
        "",
        "## First keyword events",
        "",
        f"- Base first positive: `{base_analysis['first_positive']}`",
        f"- Base first negative: `{base_analysis['first_negative']}`",
        f"- Anchor first positive: `{anchor_analysis['first_positive']}`",
        f"- Anchor first negative: `{anchor_analysis['first_negative']}`",
        "",
        "## Base continuation",
        "",
        base["continuation_text"].strip() or "_empty_",
        "",
        "## Anchor-biased continuation",
        "",
        anchor["continuation_text"].strip() or "_empty_",
        "",
    ]
    return "\n".join(lines) + "\n"
def main() -> None:
    """Run the base and anchor-biased generations and write both reports.

    Parses the command line, seeds torch, loads the overlay, generates one
    greedy base continuation and one anchor-biased continuation for the same
    prompt, analyses both with the keyword lists, writes a JSON payload and a
    markdown report, and prints a short summary.
    """
    parser = argparse.ArgumentParser(description="Run a long base vs anchor-biased Qwen compare.")
    # One row per option, in the order they are registered with argparse.
    option_table = [
        (("--model",), {"type": str, "default": "Qwen/Qwen2.5-1.5B"}),
        (("--device",), {"type": str, "default": "cuda"}),
        (("--prompt",), {"type": str, "default": DEFAULT_PROMPT}),
        (("--max_length",), {"type": int, "default": 1024}),
        (("--max_new_tokens",), {"type": int, "default": 500}),
        (("--seed",), {"type": int, "default": 7}),
        (("--conflict_threshold",), {"type": float, "default": 0.55}),
        (("--bias_scale",), {"type": float, "default": 1.50}),
        (("--repetition_penalty",), {"type": float, "default": 1.15}),
        (("--frequency_penalty",), {"type": float, "default": 0.05}),
        (("--no_repeat_ngram_size",), {"type": int, "default": 3}),
        (("--max_bias_gate_sum",), {"type": float, "default": 1.50}),
        (("--entropy_top_k",), {"type": int, "default": 32}),
        (("--entropy_threshold",), {"type": float, "default": 0.35}),
        (("--entropy_slope",), {"type": float, "default": 0.08}),
        (
            # --min_bias_pressure is an alias that stores into the same dest.
            ("--pressure_threshold", "--min_bias_pressure"),
            {"dest": "pressure_threshold", "type": float, "default": 0.60},
        ),
        (("--pressure_slope",), {"type": float, "default": 0.08}),
        (("--pressure_rescue_floor",), {"type": float, "default": 0.20}),
        (
            ("--positive_keywords",),
            {"type": str, "default": ",".join(DEFAULT_POSITIVE_KEYWORDS)},
        ),
        (
            ("--negative_keywords",),
            {"type": str, "default": ",".join(DEFAULT_NEGATIVE_KEYWORDS)},
        ),
        (
            ("--output_json",),
            {"type": Path, "default": ROOT / "archive" / "qwen_long_retention_compare.json"},
        ),
        (
            ("--output_md",),
            {
                "type": Path,
                "default": ROOT / "docs" / "research" / "qwen_long_retention_compare.md",
            },
        ),
    ]
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    cfg = replace(
        TOY_CONFIG,
        anchor_threshold=0.10,
        anchor_revision_threshold=0.35,
        anchor_contradiction_threshold=0.20,
        anchor_dead_end_threshold=0.50,
    )
    use_half_precision = "cuda" in args.device
    overlay = QwenAnchorOverlay.from_pretrained(
        model_name=args.model,
        cfg=cfg,
        device=args.device,
        torch_dtype=torch.float16 if use_half_precision else None,
    )
    overlay.eval()

    positive_keywords = _split_csv(args.positive_keywords)
    negative_keywords = _split_csv(args.negative_keywords)

    # Settings that are passed to the anchor-biased generation and echoed,
    # in this order, in both the JSON payload and the markdown report.
    bias_settings = {
        name: getattr(args, name)
        for name in (
            "conflict_threshold",
            "bias_scale",
            "repetition_penalty",
            "frequency_penalty",
            "no_repeat_ngram_size",
            "max_bias_gate_sum",
            "entropy_top_k",
            "entropy_threshold",
            "entropy_slope",
            "pressure_threshold",
            "pressure_slope",
            "pressure_rescue_floor",
        )
    }

    base = generate_base(
        overlay=overlay,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        max_length=args.max_length,
    )
    anchor = overlay.generate_with_anchor_bias(
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        max_length=args.max_length,
        greedy=True,
        **bias_settings,
    )

    base_analysis, anchor_analysis = (
        analyze_keywords(
            run["continuation_text"],
            positive_keywords=positive_keywords,
            negative_keywords=negative_keywords,
        )
        for run in (base, anchor)
    )

    payload = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "model": args.model,
        "device": args.device,
        "prompt": args.prompt,
        "max_length": args.max_length,
        "max_new_tokens": args.max_new_tokens,
        **bias_settings,
        "seed": args.seed,
        "positive_keywords": positive_keywords,
        "negative_keywords": negative_keywords,
        "base": base,
        "anchor": anchor,
        "base_analysis": base_analysis,
        "anchor_analysis": anchor_analysis,
    }
    report = build_markdown_report(
        model_name=args.model,
        device=args.device,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        max_length=args.max_length,
        **bias_settings,
        base=base,
        anchor=anchor,
        base_analysis=base_analysis,
        anchor_analysis=anchor_analysis,
    )

    # Create both output directories before writing either file.
    for destination in (args.output_json, args.output_md):
        destination.parent.mkdir(parents=True, exist_ok=True)
    args.output_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    args.output_md.write_text(report, encoding="utf-8")

    active_bias_steps = sum(
        1 for step in anchor["steps"] if step.get("bias_nonzero_anchors", 0) > 0
    )
    print(f"base_lexical_score={base_analysis['lexical_score']:.2f}")
    print(f"anchor_lexical_score={anchor_analysis['lexical_score']:.2f}")
    print(f"base_quality_score={base_analysis['quality_score']:.2f}")
    print(f"anchor_quality_score={anchor_analysis['quality_score']:.2f}")
    print(f"anchor_bias_active_steps={active_bias_steps}")
    print(f"saved_json={args.output_json}")
    print(f"saved_md={args.output_md}")
# Run the comparison only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()