# NOTE: removed non-code residue from a Hugging Face Spaces page capture
# ("Spaces / Sleeping / Sleeping") that preceded the module docstring.
| """ | |
| ParliaBench Utilities | |
| Party data, topic lists, orientation mappings, inference helpers, and validation. | |
| Source: speech_generator.py / trainer.py β ParliaBench NTUA 2025 | |
| """ | |
| import re | |
# --- Party data (from Config.PARTY_DISTRIBUTION in speech_generator.py) ------
# Per-party sampling weight and political-orientation label.
# The weights listed below sum to exactly 1.0.
PARTY_DISTRIBUTION = {
    "Conservative": {"weight": 0.59, "orientation": "Centre-right"},
    "Labour": {"weight": 0.24, "orientation": "Centre-left"},
    "Scottish National Party": {"weight": 0.05, "orientation": "Centre-left"},
    "Liberal Democrats": {"weight": 0.05, "orientation": "Centre to centre-left"},
    "Crossbench": {"weight": 0.028, "orientation": "Unknown"},
    "Democratic Unionist Party": {"weight": 0.016, "orientation": "Right"},
    "Independent": {"weight": 0.01, "orientation": "Unknown"},
    "Plaid Cymru": {"weight": 0.006, "orientation": "Centre-left to left"},
    "Green Party": {"weight": 0.005, "orientation": "Left"},
    "Non-Affiliated": {"weight": 0.003, "orientation": "Unknown"},
    "Bishops": {"weight": 0.002, "orientation": "Unknown"},
}

# Party names in the dict's insertion order (descending weight).
PARTIES = list(PARTY_DISTRIBUTION.keys())
# House restrictions (from Config.COMMONS_PARTIES / Config.LORDS_PARTIES)
# Crossbench, Non-Affiliated, and Bishops appear only in LORDS_PARTIES;
# the Scottish National Party appears only in COMMONS_PARTIES.
COMMONS_PARTIES = [
    "Conservative", "Labour", "Scottish National Party", "Liberal Democrats",
    "Democratic Unionist Party", "Independent", "Plaid Cymru", "Green Party",
]
LORDS_PARTIES = [
    "Conservative", "Labour", "Liberal Democrats", "Crossbench",
    "Non-Affiliated", "Green Party", "Bishops", "Independent",
    "Plaid Cymru", "Democratic Unionist Party",
]
# --- EuroVoc topic categories (from Config.EUROVOC_TOPICS) -------------------
# The 21 top-level EuroVoc domains, used as speech topic labels.
EUROVOC_TOPICS = [
    "POLITICS", "LAW", "AGRICULTURE, FORESTRY AND FISHERIES",
    "ENERGY", "ECONOMICS", "ENVIRONMENT", "SOCIAL QUESTIONS",
    "EDUCATION AND COMMUNICATIONS", "EMPLOYMENT AND WORKING CONDITIONS",
    "TRANSPORT", "INTERNATIONAL RELATIONS", "TRADE",
    "PRODUCTION, TECHNOLOGY AND RESEARCH", "EUROPEAN UNION",
    "SCIENCE", "GEOGRAPHY", "FINANCE", "BUSINESS AND COMPETITION",
    "INDUSTRY", "AGRI-FOODSTUFFS", "INTERNATIONAL ORGANISATIONS",
]
# --- Houses ------------------------------------------------------------------
HOUSES = ["House of Commons", "House of Lords"]
# Sampling weights per house (sum to 1.0).
HOUSE_DISTRIBUTION = {"House of Commons": 0.78, "House of Lords": 0.22}
# --- Generation parameters (from Config in speech_generator.py) --------------
# Defaults consumed at inference time; min/max word bounds are reused by
# validate_speech() below.
DEFAULT_GEN_PARAMS = {
    "temperature": 0.7,
    "top_p": 0.85,
    "repetition_penalty": 1.2,
    "max_new_tokens": 850,
    "min_words": 43,   # P10 threshold
    "max_words": 635,  # P90 threshold
}
# --- Model registry ----------------------------------------------------------
# Note: Yi is 6B (not 9B) — from ModelConfig in speech_generator.py.
# Fine-tuned models: LoRA adapters uploaded to HF model repos.
# Baseline models: loaded directly from Unsloth's 4-bit quantised repos.
# Maps display name -> Hugging Face repo id.
MODELS = {
    # Fine-tuned (LoRA adapters — argyrotsipi HF repos)
    "Mistral-7B (fine-tuned)": "argyrotsipi/parliabench-unsloth-mistral-7b-v0.3",
    "Llama-3.1-8B (fine-tuned)": "argyrotsipi/parliabench-unsloth-llama-3.1-8b",
    "Gemma-2-9B (fine-tuned)": "argyrotsipi/parliabench-unsloth-gemma-2-9b",
    "Qwen2-7B (fine-tuned)": "argyrotsipi/parliabench-unsloth-qwen-2-7b",
    "Yi-1.5-6B (fine-tuned)": "argyrotsipi/parliabench-unsloth-yi-1.5-6b",
    # Baselines (raw 4-bit quantised from Unsloth)
    "Mistral-7B (baseline)": "unsloth/mistral-7b-v0.3-bnb-4bit",
    "Llama-3.1-8B (baseline)": "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
    "Gemma-2-9B (baseline)": "unsloth/gemma-2-9b-bnb-4bit",
    "Qwen2-7B (baseline)": "unsloth/Qwen2-7B-bnb-4bit",
    "Yi-1.5-6B (baseline)": "unsloth/Yi-1.5-6B-bnb-4bit",
}
# Map display name -> model family key (for template + stop-string selection).
# Both the fine-tuned and baseline variants of a model share one family.
_FAMILY_BY_MODEL = {
    "Mistral-7B": "mistral",
    "Llama-3.1-8B": "llama",
    "Gemma-2-9B": "gemma",
    "Qwen2-7B": "qwen",
    "Yi-1.5-6B": "yi",
}
# Variant loop is outermost so key order matches the historical layout:
# all fine-tuned entries first, then all baselines.
MODEL_FAMILY = {
    f"{model} ({variant})": family
    for variant in ("fine-tuned", "baseline")
    for model, family in _FAMILY_BY_MODEL.items()
}
# Stop strings, start/end markers, and tokens to strip
# (from ModelConfig.MODELS in speech_generator.py — exact values).
# Per-family keys:
#   base_model               - HF repo of the 4-bit quantised base model
#   stop_strings             - substrings that should halt generation
#   start_marker             - marker after which the speech text begins
#   end_markers              - substrings used to truncate the speech tail
#   special_tokens_to_remove - chat-template tokens stripped from output
MODEL_CONFIG = {
    "mistral": {
        "base_model": "unsloth/mistral-7b-v0.3-bnb-4bit",
        "stop_strings": ["</s>", "\n[INST]", "\nContext:", "\nInstruction:"],
        "start_marker": "[/INST]",
        "end_markers": ["</s>", "\n[INST]", "\nContext:"],
        "special_tokens_to_remove": ["</s>", "<s>"],
    },
    "llama": {
        "base_model": "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
        "stop_strings": ["<|eot_id|>", "\n<|start_header_id|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|start_header_id|>assistant<|end_header_id|>",
        "end_markers": ["<|eot_id|>", "</s>", "<|end_of_text|>",
                        "\n<|start_header_id|>"],
        "special_tokens_to_remove": ["<|eot_id|>", "</s>", "<|end_of_text|>",
                                     "<|start_header_id|>", "<|end_header_id|>"],
    },
    "gemma": {
        "base_model": "unsloth/gemma-2-9b-bnb-4bit",
        "stop_strings": ["<end_of_turn>", "\n<start_of_turn>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<start_of_turn>model",
        "end_markers": ["<end_of_turn>", "\n<start_of_turn>user", "\n<bos>"],
        "special_tokens_to_remove": ["<end_of_turn>", "<start_of_turn>", "<bos>", "<eos>"],
    },
    # Qwen2 and Yi both use the ChatML-style <|im_start|>/<|im_end|> template.
    "qwen": {
        "base_model": "unsloth/Qwen2-7B-bnb-4bit",
        "stop_strings": ["<|im_end|>", "\n<|im_start|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|im_start|>assistant",
        "end_markers": ["<|im_end|>", "\n<|im_start|>user",
                        "\n<|im_start|>system"],
        "special_tokens_to_remove": ["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
    },
    "yi": {
        "base_model": "unsloth/Yi-1.5-6B-bnb-4bit",
        "stop_strings": ["<|im_end|>", "\n<|im_start|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|im_start|>assistant",
        "end_markers": ["<|im_end|>", "\n<|im_start|>user"],
        "special_tokens_to_remove": ["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
    },
}
# --- Helper functions --------------------------------------------------------
def get_valid_houses(party: str) -> list:
    """Return the houses in which *party* may appear.

    Parties not listed in COMMONS_PARTIES are treated as Lords-only;
    every other party may sit in either house.
    NOTE(review): LORDS_PARTIES is not consulted here, so a Commons party
    absent from it (e.g. the SNP) still gets both houses — confirm intent.
    """
    return HOUSES if party in COMMONS_PARTIES else ["House of Lords"]
def get_orientation(party: str) -> str:
    """Return the political-orientation label for *party* ("Unknown" if unlisted)."""
    entry = PARTY_DISTRIBUTION.get(party, {})
    return entry.get("orientation", "Unknown")
def build_context_string(party: str, topic: str, section: str,
                         orientation: str, house: str) -> str:
    """
    Build the pipe-separated context string used at generation time.

    Field order and the " | " separator match speech_generator.py
    (context = " | ".join(context_parts)).
    """
    return (
        f"EUROVOC TOPIC: {topic}"
        f" | SECTION: {section}"
        f" | PARTY: {party}"
        f" | POLITICAL ORIENTATION: {orientation}"
        f" | HOUSE: {house}"
    )
def count_tokens_approx(text: str) -> int:
    """Cheap token-count estimate: whitespace word count scaled by ~1.3."""
    word_count = len(text.split())
    return int(word_count * 1.3)
| # βββ Speech Validator βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Ported from SpeechValidator in speech_generator.py (9-step logic) | |
| _TEMPLATE_MARKERS = [ | |
| "\nuser", "\nassistant", "\nsystem", "\nmodel", | |
| "user\n", "assistant\n", "system\n", "model\n", | |
| "<s>", "system<|", "|>system", | |
| "Context:", "Instruction:", "EUROVOC TOPIC:", "SECTION:", | |
| "PARTY:", "POLITICAL ORIENTATION:", "HOUSE:", | |
| "<|", "|>", "<s>", "</s>", "<bos>", "<eos>", | |
| "<start_of_turn>", "<end_of_turn>", | |
| "<|im_start|>", "<|im_end|>", | |
| "[INST]", "[/INST]", "Response:", | |
| ] | |
# Substrings that indicate encoding corruption / garbage in model output.
# NOTE(review): several entries look doubly mojibake-encoded (e.g. "Ξ²β") and
# some single-character entries render identically here; they are kept
# byte-for-byte as ported from speech_generator.py — confirm against upstream.
_CORRUPTION_PATTERNS = [
    "Ξ²β", "Ξ³ΖΒ»", "Ξ²\"", "erusform", "});", "</>",
    "β", "β", "β", "β", "γ»", "β", "β", "β", "οΏ½",
    "<2mass>", "<3mass>", "<4mass>",
]
# Unicode codepoint ranges (inclusive) that should never appear in an
# English-language parliamentary speech; any hit is treated as corruption.
_FORBIDDEN_RANGES = [
    (0x4E00, 0x9FFF),  # CJK Unified Ideographs
    (0x3400, 0x4DBF),  # CJK Unified Ideographs Extension A
    (0x3040, 0x309F),  # Hiragana
    (0x30A0, 0x30FF),  # Katakana
    (0xAC00, 0xD7AF),  # Hangul Syllables
    (0x0600, 0x06FF),  # Arabic
    (0x0400, 0x04FF),  # Cyrillic
    (0x0E00, 0x0E7F),  # Thai
    (0x2580, 0x259F),  # Block Elements
    (0x2200, 0x22FF),  # Mathematical Operators
    (0x2300, 0x23FF),  # Miscellaneous Technical
]
# Prefixes (matched case-insensitively by validate_speech) that indicate the
# model refused the task or broke role instead of producing a speech.
_REFUSAL_PATTERNS = [
    "I am not capable of generating",
    "I cannot generate",
    "I'm sorry but I cannot",
    "This is a Parliamentary Speech generator",
    "You are asked to",
]
def validate_speech(text: str,
                    min_words: int = DEFAULT_GEN_PARAMS["min_words"],
                    max_words: int = DEFAULT_GEN_PARAMS["max_words"]) -> tuple:
    """
    Validate a generated speech (9-step SpeechValidator logic).

    Checks run in a fixed order and short-circuit on the first failure:
    emptiness, template leakage, encoding corruption, forbidden Unicode
    scripts, word/sequence repetition, counting patterns, length bounds,
    concatenated speeches, corrupted endings, meta-refusals.

    Args:
        text: The generated speech text.
        min_words: Minimum acceptable word count (default: P10 threshold).
        max_words: Maximum acceptable word count (default: P90 threshold).

    Returns:
        (is_valid, reason): reason is "OK" when valid, otherwise a
        machine-readable failure tag such as "TOO_SHORT: 12 words (min 43)".
    """
    if not text or not text.strip():
        return False, "EMPTY_SPEECH"
    # Step 1: Template leakage
    for marker in _TEMPLATE_MARKERS:
        if marker in text:
            return False, f"TEMPLATE_LEAK: {marker!r}"
    # Step 2: Unicode corruption — specific mojibake/garbage patterns
    for pattern in _CORRUPTION_PATTERNS:
        if pattern in text:
            return False, f"ENCODING_ERROR: {pattern!r}"
    # Step 2b: Forbidden Unicode script ranges
    for char in text:
        cp = ord(char)
        for start, end in _FORBIDDEN_RANGES:
            if start <= cp <= end:
                return False, f"UNICODE_CORRUPTION: U+{cp:04X}"
    # Lowercase once up front: the original recomputed text.lower() and
    # per-word .lower() inside the loops of steps 3, 3b, 4 and 8.
    words = text.split()
    lowered = [w.lower() for w in words]
    tl = text.lower()
    # Step 3: Repetition — same word 4+ times consecutively
    # (words of <= 3 chars are exempt; short function words repeat naturally)
    for i in range(len(words) - 3):
        w = lowered[i]
        if len(w) > 3 and lowered[i + 1:i + 4] == [w, w, w]:
            # NOTE(review): 'Γ' below looks like mojibake for '×'; the reason
            # string is kept byte-identical in case downstream matches on it.
            return False, f"REPETITION: '{w}' Γ 4"
    # Step 3b: Repeated sequences of 3-7 words, more than 3 back-to-back runs
    for seq_len in range(3, 8):
        for i in range(len(words) - seq_len * 3):
            seq = lowered[i:i + seq_len]
            count, j = 1, i + seq_len
            while j + seq_len <= len(words):
                if lowered[j:j + seq_len] == seq:
                    count += 1
                    j += seq_len
                else:
                    break
            if count > 3:
                snippet = " ".join(words[i:i + seq_len])
                return False, f"REPETITION: sequence Γ {count} '{snippet[:30]}'"
    # Step 4: Counting pattern ("first ... second ... third ...")
    ordinals = ["first", "second", "third", "fourth", "fifth",
                "sixth", "seventh", "eighth", "ninth", "tenth"]
    if sum(1 for w in ordinals if w in tl) > 5:
        return False, "REPETITION: counting_pattern"
    # Step 5: Length bounds
    wc = len(words)
    if wc < min_words:
        return False, f"TOO_SHORT: {wc} words (min {min_words})"
    if wc > max_words:
        return False, f"TOO_LONG: {wc} words (max {max_words})"
    # Step 6: Concatenated speeches — several parliamentary openings in one text
    openings = (text.count("My Lords") + text.count("Mr Speaker")
                + text.count("Madam Deputy Speaker"))
    if openings >= 4:
        return False, f"CONCATENATION: {openings} openings detected"
    # Step 7: Corrupted endings (code fragments / mojibake runs / ellipses)
    if any(text.endswith(end) for end in ["});", "βββ", "...."]):
        return False, "CORRUPTED_ENDING"
    # Step 8: Refusal / role confusion (case-insensitive prefix match)
    for p in _REFUSAL_PATTERNS:
        if tl.startswith(p.lower()):
            return False, f"META_REFUSAL: {p[:30]!r}"
    return True, "OK"