# ParliaBench / utils.py
# (Hugging Face Hub upload header, kept for provenance: uploaded by
#  argyrotsipi, "Upload 6 files", commit 5262c26, verified)
"""
ParliaBench Utilities
Party data, topic lists, orientation mappings, inference helpers, and validation.
Source: speech_generator.py / trainer.py β€” ParliaBench NTUA 2025
"""
import re
# ─── Party data (from Config.PARTY_DISTRIBUTION in speech_generator.py) ───────
# Per-party sampling weight (the weights sum to 1.0) and a coarse
# left/right orientation label used in the generation context string.
PARTY_DISTRIBUTION = {
    "Conservative": {"weight": 0.59, "orientation": "Centre-right"},
    "Labour": {"weight": 0.24, "orientation": "Centre-left"},
    "Scottish National Party": {"weight": 0.05, "orientation": "Centre-left"},
    "Liberal Democrats": {"weight": 0.05, "orientation": "Centre to centre-left"},
    "Crossbench": {"weight": 0.028, "orientation": "Unknown"},
    "Democratic Unionist Party": {"weight": 0.016, "orientation": "Right"},
    "Independent": {"weight": 0.01, "orientation": "Unknown"},
    "Plaid Cymru": {"weight": 0.006, "orientation": "Centre-left to left"},
    "Green Party": {"weight": 0.005, "orientation": "Left"},
    "Non-Affiliated": {"weight": 0.003, "orientation": "Unknown"},
    "Bishops": {"weight": 0.002, "orientation": "Unknown"},
}
# Canonical party list, in insertion (descending-weight) order.
PARTIES = list(PARTY_DISTRIBUTION.keys())
# House restrictions (from Config.COMMONS_PARTIES / Config.LORDS_PARTIES)
# NOTE: the lists are asymmetric — e.g. the Scottish National Party appears
# only in COMMONS_PARTIES, while Crossbench/Non-Affiliated/Bishops appear
# only in LORDS_PARTIES.
COMMONS_PARTIES = [
    "Conservative", "Labour", "Scottish National Party", "Liberal Democrats",
    "Democratic Unionist Party", "Independent", "Plaid Cymru", "Green Party",
]
LORDS_PARTIES = [
    "Conservative", "Labour", "Liberal Democrats", "Crossbench",
    "Non-Affiliated", "Green Party", "Bishops", "Independent",
    "Plaid Cymru", "Democratic Unionist Party",
]
# ─── EuroVoc topic categories (from Config.EUROVOC_TOPICS) ────────────────────
# 21 top-level domains named after the EU's EuroVoc thesaurus; used as the
# "EUROVOC TOPIC" field of the generation context string.
EUROVOC_TOPICS = [
    "POLITICS", "LAW", "AGRICULTURE, FORESTRY AND FISHERIES",
    "ENERGY", "ECONOMICS", "ENVIRONMENT", "SOCIAL QUESTIONS",
    "EDUCATION AND COMMUNICATIONS", "EMPLOYMENT AND WORKING CONDITIONS",
    "TRANSPORT", "INTERNATIONAL RELATIONS", "TRADE",
    "PRODUCTION, TECHNOLOGY AND RESEARCH", "EUROPEAN UNION",
    "SCIENCE", "GEOGRAPHY", "FINANCE", "BUSINESS AND COMPETITION",
    "INDUSTRY", "AGRI-FOODSTUFFS", "INTERNATIONAL ORGANISATIONS",
]
# ─── Houses ───────────────────────────────────────────────────────────────────
HOUSES = ["House of Commons", "House of Lords"]
# Sampling weights per house (sum to 1.0).
HOUSE_DISTRIBUTION = {"House of Commons": 0.78, "House of Lords": 0.22}
# ─── Generation parameters (from Config in speech_generator.py) ───────────────
# temperature / top_p / repetition_penalty / max_new_tokens match the
# Hugging Face generate() keyword names (presumably forwarded as-is;
# verify against speech_generator.py). min/max_words are the word-count
# bounds enforced by validate_speech() below.
DEFAULT_GEN_PARAMS = {
    "temperature": 0.7,
    "top_p": 0.85,
    "repetition_penalty": 1.2,
    "max_new_tokens": 850,
    "min_words": 43,  # P10 threshold
    "max_words": 635,  # P90 threshold
}
# ─── Model registry ────────────────────────────────────────────────────────────
# Note: Yi is 6B (not 9B) β€” from ModelConfig in speech_generator.py.
# Fine-tuned models: LoRA adapters uploaded to HF model repos.
# Baseline models: loaded directly from Unsloth's 4-bit quantised repos.
# Maps UI display name β†’ Hugging Face repo id.
MODELS = {
    # Fine-tuned (LoRA adapters β€” argyrotsipi HF repos)
    "Mistral-7B (fine-tuned)": "argyrotsipi/parliabench-unsloth-mistral-7b-v0.3",
    "Llama-3.1-8B (fine-tuned)": "argyrotsipi/parliabench-unsloth-llama-3.1-8b",
    "Gemma-2-9B (fine-tuned)": "argyrotsipi/parliabench-unsloth-gemma-2-9b",
    "Qwen2-7B (fine-tuned)": "argyrotsipi/parliabench-unsloth-qwen-2-7b",
    "Yi-1.5-6B (fine-tuned)": "argyrotsipi/parliabench-unsloth-yi-1.5-6b",
    # Baselines (raw 4-bit quantised from Unsloth)
    "Mistral-7B (baseline)": "unsloth/mistral-7b-v0.3-bnb-4bit",
    "Llama-3.1-8B (baseline)": "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
    "Gemma-2-9B (baseline)": "unsloth/gemma-2-9b-bnb-4bit",
    "Qwen2-7B (baseline)": "unsloth/Qwen2-7B-bnb-4bit",
    "Yi-1.5-6B (baseline)": "unsloth/Yi-1.5-6B-bnb-4bit",
}
# Map display name β†’ model family key (for template + stop-string selection)
# Kept as an explicit table: the family key is not mechanically derivable
# from the display name (e.g. "Qwen2-7B" β†’ "qwen").
MODEL_FAMILY = {
    "Mistral-7B (fine-tuned)": "mistral",
    "Llama-3.1-8B (fine-tuned)": "llama",
    "Gemma-2-9B (fine-tuned)": "gemma",
    "Qwen2-7B (fine-tuned)": "qwen",
    "Yi-1.5-6B (fine-tuned)": "yi",
    "Mistral-7B (baseline)": "mistral",
    "Llama-3.1-8B (baseline)": "llama",
    "Gemma-2-9B (baseline)": "gemma",
    "Qwen2-7B (baseline)": "qwen",
    "Yi-1.5-6B (baseline)": "yi",
}
# Stop strings, start/end markers, and tokens to strip
# (from ModelConfig.MODELS in speech_generator.py β€” exact values)
# Per-family keys (usage lives in speech_generator.py β€” confirm there):
#   base_model               β€” Unsloth 4-bit quantised HF repo for the family
#   stop_strings             β€” strings that should halt generation early
#   start_marker             β€” substring after which the model's reply begins
#   end_markers              β€” reply is presumably truncated at the first of these
#   special_tokens_to_remove β€” chat-template tokens stripped from the final text
MODEL_CONFIG = {
    "mistral": {
        "base_model": "unsloth/mistral-7b-v0.3-bnb-4bit",
        "stop_strings": ["</s>", "\n[INST]", "\nContext:", "\nInstruction:"],
        "start_marker": "[/INST]",
        "end_markers": ["</s>", "\n[INST]", "\nContext:"],
        "special_tokens_to_remove": ["</s>", "<s>"],
    },
    "llama": {
        "base_model": "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
        "stop_strings": ["<|eot_id|>", "\n<|start_header_id|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|start_header_id|>assistant<|end_header_id|>",
        "end_markers": ["<|eot_id|>", "</s>", "<|end_of_text|>",
                        "\n<|start_header_id|>"],
        "special_tokens_to_remove": ["<|eot_id|>", "</s>", "<|end_of_text|>",
                                     "<|start_header_id|>", "<|end_header_id|>"],
    },
    "gemma": {
        "base_model": "unsloth/gemma-2-9b-bnb-4bit",
        "stop_strings": ["<end_of_turn>", "\n<start_of_turn>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<start_of_turn>model",
        "end_markers": ["<end_of_turn>", "\n<start_of_turn>user", "\n<bos>"],
        "special_tokens_to_remove": ["<end_of_turn>", "<start_of_turn>", "<bos>", "<eos>"],
    },
    "qwen": {
        "base_model": "unsloth/Qwen2-7B-bnb-4bit",
        "stop_strings": ["<|im_end|>", "\n<|im_start|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|im_start|>assistant",
        "end_markers": ["<|im_end|>", "\n<|im_start|>user",
                        "\n<|im_start|>system"],
        "special_tokens_to_remove": ["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
    },
    # Yi uses the same ChatML-style markers as Qwen.
    "yi": {
        "base_model": "unsloth/Yi-1.5-6B-bnb-4bit",
        "stop_strings": ["<|im_end|>", "\n<|im_start|>user",
                         "\nContext:", "\nInstruction:"],
        "start_marker": "<|im_start|>assistant",
        "end_markers": ["<|im_end|>", "\n<|im_start|>user"],
        "special_tokens_to_remove": ["<|im_end|>", "<|im_start|>", "<|endoftext|>"],
    },
}
# ─── Helper functions ─────────────────────────────────────────────────────────
def get_valid_houses(party: str) -> list:
    """Return the houses in which *party* may sit.

    Consults both roster lists (COMMONS_PARTIES and LORDS_PARTIES) rather
    than the Commons roster alone. The previous implementation returned
    both houses for any Commons party, which wrongly allowed the Scottish
    National Party -- present in COMMONS_PARTIES but absent from
    LORDS_PARTIES, as the SNP takes no seats in the Lords -- to be paired
    with "House of Lords".

    A party found in neither roster falls back to ["House of Lords"],
    preserving the old behaviour for unknown names.
    """
    valid = []
    if party in COMMONS_PARTIES:
        valid.append("House of Commons")
    if party in LORDS_PARTIES:
        valid.append("House of Lords")
    return valid or ["House of Lords"]
def get_orientation(party: str) -> str:
    """Return the political orientation recorded for *party*.

    Falls back to "Unknown" when the party (or its orientation field)
    is not present in PARTY_DISTRIBUTION.
    """
    entry = PARTY_DISTRIBUTION.get(party, {})
    return entry.get("orientation", "Unknown")
def build_context_string(party: str, topic: str, section: str,
                         orientation: str, house: str) -> str:
    """
    Build the pipe-separated context string used at generation time.
    Matches speech_generator.py: context = " | ".join(context_parts)

    Field order is fixed: topic, section, party, orientation, house.
    """
    return (f"EUROVOC TOPIC: {topic}"
            f" | SECTION: {section}"
            f" | PARTY: {party}"
            f" | POLITICAL ORIENTATION: {orientation}"
            f" | HOUSE: {house}")
def count_tokens_approx(text: str) -> int:
    """Rough token estimate: whitespace word count scaled by ~1.3 tokens/word."""
    n_words = len(text.split())
    return int(n_words * 1.3)
# ─── Speech Validator ─────────────────────────────────────────────────────────
# Ported from SpeechValidator in speech_generator.py (9-step logic)
_TEMPLATE_MARKERS = [
"\nuser", "\nassistant", "\nsystem", "\nmodel",
"user\n", "assistant\n", "system\n", "model\n",
"<s>", "system<|", "|>system",
"Context:", "Instruction:", "EUROVOC TOPIC:", "SECTION:",
"PARTY:", "POLITICAL ORIENTATION:", "HOUSE:",
"<|", "|>", "<s>", "</s>", "<bos>", "<eos>",
"<start_of_turn>", "<end_of_turn>",
"<|im_start|>", "<|im_end|>",
"[INST]", "[/INST]", "Response:",
]
_CORRUPTION_PATTERNS = [
"β–", "Ξ³Ζ’Β»", "Ξ²\"", "erusform", "});", "</>",
"▍", "β–Œ", "β–Š", "β–ˆ", "・", "━", "┃", "β”œ", "οΏ½",
"<2mass>", "<3mass>", "<4mass>",
]
_FORBIDDEN_RANGES = [
(0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x3040, 0x309F),
(0x30A0, 0x30FF), (0xAC00, 0xD7AF), (0x0600, 0x06FF),
(0x0400, 0x04FF), (0x0E00, 0x0E7F), (0x2580, 0x259F),
(0x2200, 0x22FF), (0x2300, 0x23FF),
]
_REFUSAL_PATTERNS = [
"I am not capable of generating",
"I cannot generate",
"I'm sorry but I cannot",
"This is a Parliamentary Speech generator",
"You are asked to",
]
def validate_speech(text: str,
                    min_words: int = DEFAULT_GEN_PARAMS["min_words"],
                    max_words: int = DEFAULT_GEN_PARAMS["max_words"]) -> tuple:
    """
    Validate a generated speech.
    Returns (is_valid: bool, reason: str).

    Checks run in a fixed order and the first failure wins. Each reason
    begins with a machine-readable code (EMPTY_SPEECH, TEMPLATE_LEAK,
    ENCODING_ERROR, UNICODE_CORRUPTION, REPETITION, TOO_SHORT, TOO_LONG,
    CONCATENATION, CORRUPTED_ENDING, META_REFUSAL) -- the exact wording is
    part of the contract, so do not reword the returned strings.
    """
    if not text or not text.strip():
        return False, "EMPTY_SPEECH"
    # Step 1: Template leakage -- chat-template / prompt markers in the output
    for marker in _TEMPLATE_MARKERS:
        if marker in text:
            return False, f"TEMPLATE_LEAK: {marker!r}"
    # Step 2: Unicode corruption -- specific literal patterns
    for pattern in _CORRUPTION_PATTERNS:
        if pattern in text:
            return False, f"ENCODING_ERROR: {pattern!r}"
    # Step 2b: Forbidden Unicode script ranges (per-character scan)
    for char in text:
        cp = ord(char)
        for start, end in _FORBIDDEN_RANGES:
            if start <= cp <= end:
                return False, f"UNICODE_CORRUPTION: U+{cp:04X}"
    # Step 3: Repetition -- same word 4+ times consecutively.
    # Only words longer than 3 characters count, so short function words
    # (e.g. "the the the the") deliberately pass this check.
    words = text.split()
    for i in range(len(words) - 3):
        w = words[i].lower()
        if len(w) > 3 and all(words[i + j].lower() == w for j in range(1, 4)):
            return False, f"REPETITION: '{w}' Γ— 4"
    # Step 3b: Repeated sequences of 3-7 words. A sequence fails only when
    # it repeats back-to-back more than 3 times (case-insensitive).
    for seq_len in range(3, 8):
        for i in range(len(words) - seq_len * 3):
            seq = tuple(w.lower() for w in words[i:i + seq_len])
            count, j = 1, i + seq_len
            # Count contiguous repeats of `seq`; stop at the first break.
            while j + seq_len <= len(words):
                if tuple(w.lower() for w in words[j:j + seq_len]) == seq:
                    count += 1
                    j += seq_len
                else:
                    break
            if count > 3:
                snippet = " ".join(words[i:i + seq_len])
                return False, f"REPETITION: sequence Γ— {count} '{snippet[:30]}'"
    # Step 4: Counting pattern -- "first ... second ... third" listicle style.
    # Note: substring match, so e.g. "thirdly" also counts as "third".
    counting = ["first", "second", "third", "fourth", "fifth",
                "sixth", "seventh", "eighth", "ninth", "tenth"]
    if sum(1 for w in counting if w in text.lower()) > 5:
        return False, "REPETITION: counting_pattern"
    # Step 5: Length -- P10/P90 word-count bounds (see DEFAULT_GEN_PARAMS)
    wc = len(words)
    if wc < min_words:
        return False, f"TOO_SHORT: {wc} words (min {min_words})"
    if wc > max_words:
        return False, f"TOO_LONG: {wc} words (max {max_words})"
    # Step 6: Concatenated speeches -- several formal openings in one output
    openings = (text.count("My Lords") + text.count("Mr Speaker")
                + text.count("Madam Deputy Speaker"))
    if openings >= 4:
        return False, f"CONCATENATION: {openings} openings detected"
    # Step 7: Corrupted endings (code fragments / cursor glyphs / trail-off)
    if any(text.endswith(end) for end in ["});", "▍▍▍", "...."]):
        return False, "CORRUPTED_ENDING"
    # Step 8: Refusal / role confusion -- case-insensitive prefix match only
    tl = text.lower()
    for p in _REFUSAL_PATTERNS:
        if tl.startswith(p.lower()):
            return False, f"META_REFUSAL: {p[:30]!r}"
    return True, "OK"