"""Compare a long greedy base-model continuation with an anchor-biased one.

The script generates two continuations of the same prompt (plain greedy
decoding of the base model vs. ``QwenAnchorOverlay.generate_with_anchor_bias``),
scores both with simple keyword and repetition heuristics, and writes a JSON
payload plus a Markdown report.
"""

from __future__ import annotations

import argparse
from collections import Counter
from dataclasses import replace
from datetime import datetime, timezone
import json
from pathlib import Path
import re
import sys
from typing import Any

import torch

# Make the repository root importable so the ``src`` package resolves when the
# script is executed directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.model.config import TOY_CONFIG
from src.model.qwen_anchor_overlay import QwenAnchorOverlay

DEFAULT_PROMPT = (
    "You are a vegan chef. Write a detailed weekly meal plan with recipes for each day."
)

DEFAULT_POSITIVE_KEYWORDS = [
    "vegan",
    "plant-based",
    "tofu",
    "lentil",
    "lentils",
    "chickpea",
    "chickpeas",
    "bean",
    "beans",
    "vegetable",
    "vegetables",
    "mushroom",
    "mushrooms",
    "dairy-free",
]

DEFAULT_NEGATIVE_KEYWORDS = [
    "meat",
    "chicken",
    "beef",
    "pork",
    "bacon",
    "fish",
    "salmon",
    "tuna",
    "shrimp",
    "egg",
    "eggs",
    "cheese",
    "butter",
    "milk",
    "cream",
    "yogurt",
    "sausage",
    "ham",
]

# A negative keyword is "protected" (not counted against the text) when one of
# these qualifiers appears in the few characters immediately preceding it,
# e.g. "vegan cheese".
_PROTECTIVE_PREFIXES = ("vegan ", "plant-based ")
_PROTECTIVE_WINDOW = 12


def _split_csv(value: str) -> list[str]:
    """Split a comma-separated string into stripped, non-empty items."""
    return [item.strip() for item in value.split(",") if item.strip()]


def _count_active_bias_steps(anchor: dict[str, Any]) -> int:
    """Count generation steps whose ``bias_nonzero_anchors`` value is positive.

    Steps that do not carry the key are treated as inactive.
    """
    return sum(1 for step in anchor["steps"] if step.get("bias_nonzero_anchors", 0) > 0)


def generate_base(
    overlay: QwenAnchorOverlay,
    prompt: str,
    max_new_tokens: int,
    max_length: int,
) -> dict[str, Any]:
    """Greedily decode a continuation of ``prompt`` with ``overlay.base_model``.

    The full sequence is re-fed to the model on every step (no KV cache).
    Decoding stops after ``max_new_tokens`` tokens, when the tokenizer's EOS
    token is produced, or when the sequence reaches ``max_length`` tokens.

    Args:
        overlay: Overlay exposing ``tokenizer``, ``base_model`` and
            ``parameters()``.
        prompt: Text to continue; truncated to ``max_length`` tokens.
        max_new_tokens: Upper bound on the number of generated tokens.
        max_length: Upper bound on prompt plus generated tokens.

    Returns:
        A dict with ``prompt``, ``generated_text`` (prompt + continuation),
        ``continuation_text`` and ``steps`` (one ``token_id``/``token_text``
        entry per generated token).

    Raises:
        ValueError: If the overlay has no tokenizer.
    """
    tokenizer = overlay.tokenizer
    if tokenizer is None:
        raise ValueError("tokenizer is required")

    device = next(overlay.parameters()).device
    encoded = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # A tokenizer may expose ``eos_token_id`` as None; never match in that case.
    eos_token_id = getattr(tokenizer, "eos_token_id", None)

    generated = input_ids
    generated_mask = attention_mask
    steps: list[dict[str, Any]] = []

    # Pure inference: skip autograd bookkeeping for every forward pass.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Check the budget before generating so a prompt that already
            # fills ``max_length`` is not extended past the limit.
            if generated.size(1) >= max_length:
                break
            outputs = overlay.base_model(
                input_ids=generated,
                attention_mask=generated_mask,
                return_dict=True,
            )
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            generated_mask = torch.cat(
                [
                    generated_mask,
                    torch.ones((1, 1), device=device, dtype=generated_mask.dtype),
                ],
                dim=1,
            )
            token_id = int(next_token.item())
            steps.append(
                {
                    "token_id": token_id,
                    "token_text": tokenizer.decode([token_id], skip_special_tokens=False),
                }
            )
            if eos_token_id is not None and token_id == int(eos_token_id):
                break

    continuation_ids = generated[0, input_ids.size(1) :]
    continuation_text = tokenizer.decode(continuation_ids, skip_special_tokens=True)
    return {
        "prompt": prompt,
        "generated_text": tokenizer.decode(generated[0], skip_special_tokens=True),
        "continuation_text": continuation_text,
        "steps": steps,
    }


def _keyword_positions(lowered: str, token: str) -> list[int]:
    """Return start offsets of whole-word occurrences of ``token`` in ``lowered``.

    Matching is case-insensitive with respect to ``token`` (``lowered`` is
    expected to be lowercase already). Blank tokens never match.
    """
    needle = token.strip().lower()
    if not needle:
        return []
    # Lookarounds instead of ``\b`` so keywords that start or end with a
    # non-word character still match sensibly.
    pattern = re.compile(rf"(?<!\w){re.escape(needle)}(?!\w)")
    return [match.start() for match in pattern.finditer(lowered)]


def _is_protected(lowered: str, index: int) -> bool:
    """Tell whether the text just before ``index`` contains a vegan qualifier."""
    window = lowered[max(0, index - _PROTECTIVE_WINDOW) : index]
    return any(prefix in window for prefix in _PROTECTIVE_PREFIXES)


def _repetition_metrics(lowered: str) -> tuple[float, int]:
    """Return ``(repeated_bigram_ratio, max_sentence_repeat)`` for ``lowered``."""
    words = re.findall(r"[a-zA-Z_]+", lowered)
    bigrams = list(zip(words, words[1:]))
    repeated_bigram_ratio = (1.0 - len(set(bigrams)) / len(bigrams)) if bigrams else 0.0

    sentence_counts = Counter(
        sentence.strip()
        for sentence in re.split(r"[.!?\n]+", lowered)
        if sentence.strip()
    )
    max_sentence_repeat = max(sentence_counts.values(), default=0)
    return repeated_bigram_ratio, max_sentence_repeat


def analyze_keywords(
    text: str,
    positive_keywords: list[str],
    negative_keywords: list[str],
) -> dict[str, Any]:
    """Score ``text`` by keyword hits and repetition heuristics.

    Keywords are matched case-insensitively as whole words, so ``"egg"`` does
    not fire on ``"eggplant"`` and ``"beans"`` is not also counted as
    ``"bean"``. A negative keyword preceded closely by ``"vegan "`` or
    ``"plant-based "`` is recorded as protected instead of as a negative hit.

    Args:
        text: Text to analyse.
        positive_keywords: Keywords that raise the lexical score.
        negative_keywords: Keywords that lower the lexical score.

    Returns:
        A dict with per-keyword hit counts (``positive_hits``,
        ``negative_hits``, ``protected_negative_hits``), their totals, the
        earliest positive / unprotected negative hit (``first_positive``,
        ``first_negative``; ``None`` when absent), ``lexical_score``,
        ``repeated_bigram_ratio``, ``max_sentence_repeat``,
        ``degeneracy_penalty`` and ``quality_score``.
    """
    lowered = text.lower()

    positive_hits: dict[str, int] = {}
    first_positive: dict[str, Any] | None = None
    for token in positive_keywords:
        positions = _keyword_positions(lowered, token)
        if not positions:
            continue
        positive_hits[token] = len(positions)
        # Strict "<" keeps the earlier-listed keyword on ties.
        if first_positive is None or positions[0] < first_positive["char_index"]:
            first_positive = {"token": token, "char_index": positions[0]}

    negative_hits: dict[str, int] = {}
    protected_negative_hits: dict[str, int] = {}
    first_negative: dict[str, Any] | None = None
    for token in negative_keywords:
        positions = _keyword_positions(lowered, token)
        if not positions:
            continue
        exposed = [index for index in positions if not _is_protected(lowered, index)]
        protected = len(positions) - len(exposed)
        if exposed:
            negative_hits[token] = len(exposed)
            # Only unprotected hits qualify, keeping this consistent with
            # ``negative_total``.
            if first_negative is None or exposed[0] < first_negative["char_index"]:
                first_negative = {"token": token, "char_index": exposed[0]}
        if protected > 0:
            protected_negative_hits[token] = protected

    positive_total = sum(positive_hits.values())
    negative_total = sum(negative_hits.values())
    lexical_score = float(positive_total - negative_total)

    repeated_bigram_ratio, max_sentence_repeat = _repetition_metrics(lowered)
    degeneracy_penalty = float(
        8.0 * repeated_bigram_ratio + 1.5 * max(0, max_sentence_repeat - 1)
    )
    # Negative hits are penalised again here on top of the lexical score.
    quality_score = float(lexical_score - degeneracy_penalty - 1.5 * negative_total)

    return {
        "positive_hits": positive_hits,
        "negative_hits": negative_hits,
        "protected_negative_hits": protected_negative_hits,
        "positive_total": int(positive_total),
        "negative_total": int(negative_total),
        "first_positive": first_positive,
        "first_negative": first_negative,
        "lexical_score": lexical_score,
        "repeated_bigram_ratio": float(repeated_bigram_ratio),
        "max_sentence_repeat": int(max_sentence_repeat),
        "degeneracy_penalty": float(degeneracy_penalty),
        "quality_score": float(quality_score),
    }


def build_markdown_report(
    *,
    model_name: str,
    device: str,
    prompt: str,
    max_new_tokens: int,
    max_length: int,
    conflict_threshold: float,
    bias_scale: float,
    repetition_penalty: float,
    frequency_penalty: float,
    no_repeat_ngram_size: int,
    max_bias_gate_sum: float,
    entropy_top_k: int,
    entropy_threshold: float,
    entropy_slope: float,
    pressure_threshold: float,
    pressure_slope: float,
    pressure_rescue_floor: float,
    base: dict[str, Any],
    anchor: dict[str, Any],
    base_analysis: dict[str, Any],
    anchor_analysis: dict[str, Any],
) -> str:
    """Render the run settings, scores and both continuations as Markdown.

    Args:
        model_name, device, prompt and the numeric settings: Echoed verbatim
            into the report header.
        base: Base generation result; ``continuation_text`` is read.
        anchor: Anchor-biased generation result; ``continuation_text`` and
            ``steps`` are read.
        base_analysis: ``analyze_keywords`` result for the base continuation.
        anchor_analysis: ``analyze_keywords`` result for the anchor
            continuation.

    Returns:
        The report text, terminated by a newline. The ``Date`` line uses the
        current UTC time.
    """
    active_bias_steps = _count_active_bias_steps(anchor)
    lines = [
        "# Qwen Long Retention Compare",
        "",
        f"Date: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}",
        f"Model: `{model_name}`",
        f"Device: `{device}`",
        f"Max new tokens: `{max_new_tokens}`",
        f"Max length: `{max_length}`",
        f"Conflict threshold: `{conflict_threshold:.2f}`",
        f"Bias scale: `{bias_scale:.2f}`",
        f"Repetition penalty: `{repetition_penalty:.2f}`",
        f"Frequency penalty: `{frequency_penalty:.2f}`",
        f"No-repeat ngram size: `{no_repeat_ngram_size}`",
        f"Max bias gate sum: `{max_bias_gate_sum:.2f}`",
        f"Entropy top-k: `{entropy_top_k}`",
        f"Entropy threshold: `{entropy_threshold:.2f}`",
        f"Entropy slope: `{entropy_slope:.2f}`",
        f"Pressure threshold: `{pressure_threshold:.2f}`",
        f"Pressure slope: `{pressure_slope:.2f}`",
        f"Pressure rescue floor: `{pressure_rescue_floor:.2f}`",
        "",
        "## Prompt",
        "",
        f"> {prompt}",
        "",
        "## Summary",
        "",
        f"- Base lexical score: `{base_analysis['lexical_score']:.2f}`",
        f"- Anchor lexical score: `{anchor_analysis['lexical_score']:.2f}`",
        f"- Base quality score: `{base_analysis['quality_score']:.2f}`",
        f"- Anchor quality score: `{anchor_analysis['quality_score']:.2f}`",
        f"- Base positive hits: `{base_analysis['positive_total']}`",
        f"- Anchor positive hits: `{anchor_analysis['positive_total']}`",
        f"- Base negative hits: `{base_analysis['negative_total']}`",
        f"- Anchor negative hits: `{anchor_analysis['negative_total']}`",
        f"- Anchor protected negative hits: `{sum(anchor_analysis['protected_negative_hits'].values())}`",
        f"- Base degeneracy penalty: `{base_analysis['degeneracy_penalty']:.2f}`",
        f"- Anchor degeneracy penalty: `{anchor_analysis['degeneracy_penalty']:.2f}`",
        f"- Anchor bias active steps: `{active_bias_steps}`",
        f"- Continuations identical: `{'yes' if base['continuation_text'] == anchor['continuation_text'] else 'no'}`",
        "",
        "## First keyword events",
        "",
        f"- Base first positive: `{base_analysis['first_positive']}`",
        f"- Base first negative: `{base_analysis['first_negative']}`",
        f"- Anchor first positive: `{anchor_analysis['first_positive']}`",
        f"- Anchor first negative: `{anchor_analysis['first_negative']}`",
        "",
        "## Base continuation",
        "",
        base["continuation_text"].strip() or "_empty_",
        "",
        "## Anchor-biased continuation",
        "",
        anchor["continuation_text"].strip() or "_empty_",
        "",
    ]
    return "\n".join(lines) + "\n"


def main() -> None:
    """Parse CLI arguments, run both generations, and write the JSON/Markdown outputs."""
    parser = argparse.ArgumentParser(description="Run a long base vs anchor-biased Qwen compare.")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT)
    parser.add_argument("--max_length", type=int, default=1024)
    parser.add_argument("--max_new_tokens", type=int, default=500)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--conflict_threshold", type=float, default=0.55)
    parser.add_argument("--bias_scale", type=float, default=1.50)
    parser.add_argument("--repetition_penalty", type=float, default=1.15)
    parser.add_argument("--frequency_penalty", type=float, default=0.05)
    parser.add_argument("--no_repeat_ngram_size", type=int, default=3)
    parser.add_argument("--max_bias_gate_sum", type=float, default=1.50)
    parser.add_argument("--entropy_top_k", type=int, default=32)
    parser.add_argument("--entropy_threshold", type=float, default=0.35)
    parser.add_argument("--entropy_slope", type=float, default=0.08)
    # ``--min_bias_pressure`` is accepted as an alias for ``--pressure_threshold``.
    parser.add_argument(
        "--pressure_threshold",
        "--min_bias_pressure",
        dest="pressure_threshold",
        type=float,
        default=0.60,
    )
    parser.add_argument("--pressure_slope", type=float, default=0.08)
    parser.add_argument("--pressure_rescue_floor", type=float, default=0.20)
    parser.add_argument(
        "--positive_keywords",
        type=str,
        default=",".join(DEFAULT_POSITIVE_KEYWORDS),
    )
    parser.add_argument(
        "--negative_keywords",
        type=str,
        default=",".join(DEFAULT_NEGATIVE_KEYWORDS),
    )
    parser.add_argument(
        "--output_json",
        type=Path,
        default=ROOT / "archive" / "qwen_long_retention_compare.json",
    )
    parser.add_argument(
        "--output_md",
        type=Path,
        default=ROOT / "docs" / "research" / "qwen_long_retention_compare.md",
    )
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    cfg = replace(
        TOY_CONFIG,
        anchor_threshold=0.10,
        anchor_revision_threshold=0.35,
        anchor_contradiction_threshold=0.20,
        anchor_dead_end_threshold=0.50,
    )
    overlay = QwenAnchorOverlay.from_pretrained(
        model_name=args.model,
        cfg=cfg,
        device=args.device,
        # Half precision only when targeting a CUDA device.
        torch_dtype=torch.float16 if "cuda" in args.device else None,
    )
    overlay.eval()

    positive_keywords = _split_csv(args.positive_keywords)
    negative_keywords = _split_csv(args.negative_keywords)

    # Generation settings shared by the anchor-biased run, the JSON payload
    # and the Markdown report.
    bias_settings = {
        "conflict_threshold": args.conflict_threshold,
        "bias_scale": args.bias_scale,
        "repetition_penalty": args.repetition_penalty,
        "frequency_penalty": args.frequency_penalty,
        "no_repeat_ngram_size": args.no_repeat_ngram_size,
        "max_bias_gate_sum": args.max_bias_gate_sum,
        "entropy_top_k": args.entropy_top_k,
        "entropy_threshold": args.entropy_threshold,
        "entropy_slope": args.entropy_slope,
        "pressure_threshold": args.pressure_threshold,
        "pressure_slope": args.pressure_slope,
        "pressure_rescue_floor": args.pressure_rescue_floor,
    }

    base = generate_base(
        overlay=overlay,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        max_length=args.max_length,
    )
    anchor = overlay.generate_with_anchor_bias(
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        max_length=args.max_length,
        greedy=True,
        **bias_settings,
    )

    base_analysis = analyze_keywords(
        base["continuation_text"],
        positive_keywords=positive_keywords,
        negative_keywords=negative_keywords,
    )
    anchor_analysis = analyze_keywords(
        anchor["continuation_text"],
        positive_keywords=positive_keywords,
        negative_keywords=negative_keywords,
    )

    payload = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "model": args.model,
        "device": args.device,
        "prompt": args.prompt,
        "max_length": args.max_length,
        "max_new_tokens": args.max_new_tokens,
        **bias_settings,
        "seed": args.seed,
        "positive_keywords": positive_keywords,
        "negative_keywords": negative_keywords,
        "base": base,
        "anchor": anchor,
        "base_analysis": base_analysis,
        "anchor_analysis": anchor_analysis,
    }
    report = build_markdown_report(
        model_name=args.model,
        device=args.device,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        max_length=args.max_length,
        base=base,
        anchor=anchor,
        base_analysis=base_analysis,
        anchor_analysis=anchor_analysis,
        **bias_settings,
    )

    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    args.output_md.parent.mkdir(parents=True, exist_ok=True)
    args.output_json.write_text(
        json.dumps(payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    args.output_md.write_text(report, encoding="utf-8")

    print(f"base_lexical_score={base_analysis['lexical_score']:.2f}")
    print(f"anchor_lexical_score={anchor_analysis['lexical_score']:.2f}")
    print(f"base_quality_score={base_analysis['quality_score']:.2f}")
    print(f"anchor_quality_score={anchor_analysis['quality_score']:.2f}")
    print(f"anchor_bias_active_steps={_count_active_bias_steps(anchor)}")
    print(f"saved_json={args.output_json}")
    print(f"saved_md={args.output_md}")


if __name__ == "__main__":
    main()