""" ParliaBench Utilities Party data, topic lists, orientation mappings, inference helpers, and validation. Source: speech_generator.py / trainer.py — ParliaBench NTUA 2025 """ import re # ─── Party data (from Config.PARTY_DISTRIBUTION in speech_generator.py) ─────── PARTY_DISTRIBUTION = { "Conservative": {"weight": 0.59, "orientation": "Centre-right"}, "Labour": {"weight": 0.24, "orientation": "Centre-left"}, "Scottish National Party": {"weight": 0.05, "orientation": "Centre-left"}, "Liberal Democrats": {"weight": 0.05, "orientation": "Centre to centre-left"}, "Crossbench": {"weight": 0.028, "orientation": "Unknown"}, "Democratic Unionist Party": {"weight": 0.016, "orientation": "Right"}, "Independent": {"weight": 0.01, "orientation": "Unknown"}, "Plaid Cymru": {"weight": 0.006, "orientation": "Centre-left to left"}, "Green Party": {"weight": 0.005, "orientation": "Left"}, "Non-Affiliated": {"weight": 0.003, "orientation": "Unknown"}, "Bishops": {"weight": 0.002, "orientation": "Unknown"}, } PARTIES = list(PARTY_DISTRIBUTION.keys()) # House restrictions (from Config.COMMONS_PARTIES / Config.LORDS_PARTIES) COMMONS_PARTIES = [ "Conservative", "Labour", "Scottish National Party", "Liberal Democrats", "Democratic Unionist Party", "Independent", "Plaid Cymru", "Green Party", ] LORDS_PARTIES = [ "Conservative", "Labour", "Liberal Democrats", "Crossbench", "Non-Affiliated", "Green Party", "Bishops", "Independent", "Plaid Cymru", "Democratic Unionist Party", ] # ─── EuroVoc topic categories (from Config.EUROVOC_TOPICS) ──────────────────── EUROVOC_TOPICS = [ "POLITICS", "LAW", "AGRICULTURE, FORESTRY AND FISHERIES", "ENERGY", "ECONOMICS", "ENVIRONMENT", "SOCIAL QUESTIONS", "EDUCATION AND COMMUNICATIONS", "EMPLOYMENT AND WORKING CONDITIONS", "TRANSPORT", "INTERNATIONAL RELATIONS", "TRADE", "PRODUCTION, TECHNOLOGY AND RESEARCH", "EUROPEAN UNION", "SCIENCE", "GEOGRAPHY", "FINANCE", "BUSINESS AND COMPETITION", "INDUSTRY", "AGRI-FOODSTUFFS", "INTERNATIONAL ORGANISATIONS", ] # 
─── Houses ─────────────────────────────────────────────────────────────────── HOUSES = ["House of Commons", "House of Lords"] HOUSE_DISTRIBUTION = {"House of Commons": 0.78, "House of Lords": 0.22} # ─── Generation parameters (from Config in speech_generator.py) ─────────────── DEFAULT_GEN_PARAMS = { "temperature": 0.7, "top_p": 0.85, "repetition_penalty": 1.2, "max_new_tokens": 850, "min_words": 43, # P10 threshold "max_words": 635, # P90 threshold } # ─── Model registry ──────────────────────────────────────────────────────────── # Note: Yi is 6B (not 9B) — from ModelConfig in speech_generator.py. # Fine-tuned models: LoRA adapters uploaded to HF model repos. # Baseline models: loaded directly from Unsloth's 4-bit quantised repos. MODELS = { # Fine-tuned (LoRA adapters — argyrotsipi HF repos) "Mistral-7B (fine-tuned)": "argyrotsipi/parliabench-unsloth-mistral-7b-v0.3", "Llama-3.1-8B (fine-tuned)": "argyrotsipi/parliabench-unsloth-llama-3.1-8b", "Gemma-2-9B (fine-tuned)": "argyrotsipi/parliabench-unsloth-gemma-2-9b", "Qwen2-7B (fine-tuned)": "argyrotsipi/parliabench-unsloth-qwen-2-7b", "Yi-1.5-6B (fine-tuned)": "argyrotsipi/parliabench-unsloth-yi-1.5-6b", # Baselines (raw 4-bit quantised from Unsloth) "Mistral-7B (baseline)": "unsloth/mistral-7b-v0.3-bnb-4bit", "Llama-3.1-8B (baseline)": "unsloth/Meta-Llama-3.1-8B-bnb-4bit", "Gemma-2-9B (baseline)": "unsloth/gemma-2-9b-bnb-4bit", "Qwen2-7B (baseline)": "unsloth/Qwen2-7B-bnb-4bit", "Yi-1.5-6B (baseline)": "unsloth/Yi-1.5-6B-bnb-4bit", } # Map display name → model family key (for template + stop-string selection) MODEL_FAMILY = { "Mistral-7B (fine-tuned)": "mistral", "Llama-3.1-8B (fine-tuned)": "llama", "Gemma-2-9B (fine-tuned)": "gemma", "Qwen2-7B (fine-tuned)": "qwen", "Yi-1.5-6B (fine-tuned)": "yi", "Mistral-7B (baseline)": "mistral", "Llama-3.1-8B (baseline)": "llama", "Gemma-2-9B (baseline)": "gemma", "Qwen2-7B (baseline)": "qwen", "Yi-1.5-6B (baseline)": "yi", } # Stop strings, start/end markers, and tokens 
# (from ModelConfig.MODELS in speech_generator.py — exact values)
# NOTE(review): the pasted source was markup-mangled — strings that looked
# like HTML/XML tags (e.g. "</s>", "<start_of_turn>") were stripped to "".
# An empty stop/end marker matches every text, so empties are unusable.
# High-confidence tokens are reconstructed below from each model family's
# published chat template; unreconstructable empties are dropped.
# TODO(review): confirm every reconstructed token against speech_generator.py.
MODEL_CONFIG = {
    "mistral": {
        "base_model": "unsloth/mistral-7b-v0.3-bnb-4bit",
        # "</s>" / "<s>" are Mistral's EOS/BOS tokens — reconstructed from "".
        "stop_strings": ["</s>", "\n[INST]", "\nContext:", "\nInstruction:"],
        "start_marker": "[/INST]",
        "end_markers": ["</s>", "\n[INST]", "\nContext:"],
        "special_tokens_to_remove": ["<s>", "</s>"],
    },
    "llama": {
        "base_model": "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
        "stop_strings": ["<|eot_id|>", "\n<|start_header_id|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|start_header_id|>assistant<|end_header_id|>",
        # One entry in each list below was stripped to "" and could not be
        # reconstructed; dropped rather than kept as an always-matching "".
        "end_markers": ["<|eot_id|>", "<|end_of_text|>", "\n<|start_header_id|>"],
        "special_tokens_to_remove": ["<|eot_id|>", "<|end_of_text|>",
                                     "<|start_header_id|>", "<|end_header_id|>"],
    },
    "gemma": {
        "base_model": "unsloth/gemma-2-9b-bnb-4bit",
        # Gemma-2 turn markers reconstructed: the mangled source kept only
        # "model" / "\nuser", i.e. "<start_of_turn>model" etc. with the
        # tag portion stripped.
        "stop_strings": ["<end_of_turn>", "\n<start_of_turn>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<start_of_turn>model",
        "end_markers": ["<end_of_turn>", "\n<start_of_turn>user", "\n<eos>"],
        "special_tokens_to_remove": ["<start_of_turn>", "<end_of_turn>",
                                     "<bos>", "<eos>"],
    },
    "qwen": {
        "base_model": "unsloth/Qwen2-7B-bnb-4bit",
        "stop_strings": ["<|im_end|>", "\n<|im_start|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|im_start|>assistant",
        "end_markers": ["<|im_end|>", "\n<|im_start|>user", "\n<|im_start|>system"],
        "special_tokens_to_remove": ["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
    },
    "yi": {
        # Yi 1.5 uses the same ChatML-style markers as Qwen.
        "base_model": "unsloth/Yi-1.5-6B-bnb-4bit",
        "stop_strings": ["<|im_end|>", "\n<|im_start|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|im_start|>assistant",
        "end_markers": ["<|im_end|>", "\n<|im_start|>user"],
        "special_tokens_to_remove": ["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
    },
}

# ─── Helper functions ─────────────────────────────────────────────────────────

def get_valid_houses(party: str) -> list:
    """Return the list of houses *party* may sit in.

    Consults BOTH chamber lists (the original checked only COMMONS_PARTIES,
    which wrongly allowed the Scottish National Party — absent from
    LORDS_PARTIES — into the Lords).  Parties in neither list fall back to
    the Lords, matching the original default for unknown parties.
    """
    valid = []
    if party in COMMONS_PARTIES:
        valid.append("House of Commons")
    if party in LORDS_PARTIES:
        valid.append("House of Lords")
    # Fallback for unknown parties (original behaviour).
    return valid or ["House of Lords"]


def get_orientation(party: str) -> str:
    """Return the political-orientation label for *party* ("Unknown" if unlisted)."""
    return PARTY_DISTRIBUTION.get(party, {}).get("orientation", "Unknown")
def build_context_string(party: str, topic: str, section: str,
                         orientation: str, house: str) -> str:
    """
    Build the pipe-separated context string used at generation time.
    Matches speech_generator.py: context = " | ".join(context_parts)
    """
    parts = [
        f"EUROVOC TOPIC: {topic}",
        f"SECTION: {section}",
        f"PARTY: {party}",
        f"POLITICAL ORIENTATION: {orientation}",
        f"HOUSE: {house}",
    ]
    return " | ".join(parts)


def count_tokens_approx(text: str) -> int:
    """Rough token estimate (~words × 1.3)."""
    return int(len(text.split()) * 1.3)


# ─── Speech Validator ─────────────────────────────────────────────────────────
# Ported from SpeechValidator in speech_generator.py (9-step logic)
#
# NOTE(review): the pasted source contained several "" entries in the marker
# lists below (tag-like tokens such as "</s>" stripped by markup-mangling).
# Since `"" in text` is always True, those entries made validate_speech
# reject EVERY speech with TEMPLATE_LEAK — they are removed here, and the
# membership loops additionally skip falsy markers as a safeguard.

_TEMPLATE_MARKERS = [
    "\nuser", "\nassistant", "\nsystem", "\nmodel",
    "user\n", "assistant\n", "system\n", "model\n",
    "system<|", "|>system",
    "Context:", "Instruction:",
    "EUROVOC TOPIC:", "SECTION:", "PARTY:", "POLITICAL ORIENTATION:", "HOUSE:",
    "<|", "|>",
    "<|im_start|>", "<|im_end|>",
    "[INST]", "[/INST]",
    "Response:",
]

_CORRUPTION_PATTERNS = [
    "β–", "・", "β\"", "erusform", "});",
    "▍", "▌", "▊", "█",
    "・", "━", "┃", "├", "�",
    "<2mass>", "<3mass>", "<4mass>",
]

# Unicode ranges whose presence indicates decoder corruption (speeches are
# expected to be English-only).
_FORBIDDEN_RANGES = [
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0x3400, 0x4DBF),  # CJK Extension A
    (0x3040, 0x309F),  # Hiragana
    (0x30A0, 0x30FF),  # Katakana
    (0xAC00, 0xD7AF),  # Hangul syllables
    (0x0600, 0x06FF),  # Arabic
    (0x0400, 0x04FF),  # Cyrillic
    (0x0E00, 0x0E7F),  # Thai
    (0x2580, 0x259F),  # Block Elements
    (0x2200, 0x22FF),  # Mathematical Operators
    (0x2300, 0x23FF),  # Miscellaneous Technical
]

_REFUSAL_PATTERNS = [
    "I am not capable of generating",
    "I cannot generate",
    "I'm sorry but I cannot",
    "This is a Parliamentary Speech generator",
    "You are asked to",
]


def validate_speech(text: str,
                    min_words: int = DEFAULT_GEN_PARAMS["min_words"],
                    max_words: int = DEFAULT_GEN_PARAMS["max_words"]) -> tuple:
    """
    Validate a generated speech.

    Runs the SpeechValidator checks in order; the first failing step wins.

    Args:
        text: the generated speech.
        min_words / max_words: word-count bounds (P10/P90 defaults).

    Returns:
        (is_valid, reason): reason is "OK" when valid, otherwise a short
        failure code such as "TEMPLATE_LEAK: ..." or "TOO_SHORT: ...".
    """
    if not text or not text.strip():
        return False, "EMPTY_SPEECH"

    # Step 1: chat-template leakage.  Falsy markers are skipped — an empty
    # marker would match every text (see NOTE above the marker lists).
    for marker in _TEMPLATE_MARKERS:
        if marker and marker in text:
            return False, f"TEMPLATE_LEAK: {marker!r}"

    # Step 2: known corruption byte-sequences.
    for pattern in _CORRUPTION_PATTERNS:
        if pattern and pattern in text:
            return False, f"ENCODING_ERROR: {pattern!r}"

    # Step 2b: forbidden Unicode script ranges.
    for char in text:
        cp = ord(char)
        for start, end in _FORBIDDEN_RANGES:
            if start <= cp <= end:
                return False, f"UNICODE_CORRUPTION: U+{cp:04X}"

    # Step 3: same word (length > 3) repeated 4+ times consecutively.
    words = text.split()
    for i in range(len(words) - 3):
        w = words[i].lower()
        if len(w) > 3 and all(words[i + j].lower() == w for j in range(1, 4)):
            return False, f"REPETITION: '{w}' × 4"

    # Step 3b: a 3–7-word sequence repeated back-to-back more than 3 times.
    for seq_len in range(3, 8):
        for i in range(len(words) - seq_len * 3):
            seq = tuple(w.lower() for w in words[i:i + seq_len])
            count, j = 1, i + seq_len
            while j + seq_len <= len(words):
                if tuple(w.lower() for w in words[j:j + seq_len]) == seq:
                    count += 1
                    j += seq_len
                else:
                    break
            if count > 3:
                snippet = " ".join(words[i:i + seq_len])
                return False, f"REPETITION: sequence × {count} '{snippet[:30]}'"

    # Step 4: degenerate "first … tenth" enumeration pattern.
    counting = ["first", "second", "third", "fourth", "fifth",
                "sixth", "seventh", "eighth", "ninth", "tenth"]
    if sum(1 for w in counting if w in text.lower()) > 5:
        return False, "REPETITION: counting_pattern"

    # Step 5: word-count bounds.
    wc = len(words)
    if wc < min_words:
        return False, f"TOO_SHORT: {wc} words (min {min_words})"
    if wc > max_words:
        return False, f"TOO_LONG: {wc} words (max {max_words})"

    # Step 6: several parliamentary openings ⇒ concatenated speeches.
    openings = (text.count("My Lords")
                + text.count("Mr Speaker")
                + text.count("Madam Deputy Speaker"))
    if openings >= 4:
        return False, f"CONCATENATION: {openings} openings detected"

    # Step 7: corrupted trailing output.
    if any(text.endswith(end) for end in ["});", "▍▍▍", "...."]):
        return False, "CORRUPTED_ENDING"

    # Step 8: model refused or broke role instead of producing a speech
    # (only checked at the start of the text, as in the original).
    tl = text.lower()
    for p in _REFUSAL_PATTERNS:
        if tl.startswith(p.lower()):
            return False, f"META_REFUSAL: {p[:30]!r}"

    return True, "OK"