mastermap-cleaner / src /llm_router.py
andrewbejjani's picture
Added functional doc in README.md and added basic
c6a3f44
import requests
import time
from tqdm import tqdm
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
class RateLimitException(Exception):
    """Signal that the active Groq model answered with HTTP 429.

    The request helper raises this on a 429 status so the retry policy can
    back off, and so the router can rotate to the next configured model once
    the retries are used up.
    """
class GroqRouter:
    """Small Groq client that rotates through configured fallback models.

    The router keeps an index into ``available_models``. When the active
    model keeps answering HTTP 429 after the retry budget is spent, the index
    is advanced and never reset, so later calls continue on the fallback
    model. Once the index runs past the end of the list every call returns
    immediately with ``API_Error_All_Models_Dead``.
    """

    # Chat-completions endpoint every request is posted to.
    API_URL = "https://api.groq.com/openai/v1/chat/completions"
    # Pause before each request; light throttling reduces avoidable
    # rate-limit pressure.
    THROTTLE_SECONDS = 0.3
    # Per-request timeout handed to ``requests.post``.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self, api_key, available_models):
        """Store the credentials and the ordered list of model names.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header. A
                falsy value makes ``ask_judge`` return ``No_API_Key``.
            available_models: Sequence of model names, tried in order.
        """
        self.api_key = api_key
        self.available_models = available_models
        self.current_model_index = 0
        # Remembers the last model announced so the routing banner is only
        # printed when the active model changes.
        self.last_printed_model = None

    @staticmethod
    def _column_rules(column_name):
        """Return the column-specific prompt rules for ``column_name``."""
        # Column-specific prompt rules prevent the model from over-splitting
        # institutions while still translating geography values to English.
        if column_name in ("Institution", "Degree"):
            return (
                "- Split distinct separate schools or global alliances with a comma (e.g., 'Harvard & MIT' -> 'Harvard University, MIT').\n"
                "- Treat alliances like CEMS as distinct institutions! NEVER delete them.\n"
                "- Do NOT split a Business School and its Parent University. They are ONE entity.\n"
                "- Do NOT split a school and its acronym/alias. Keep them as one entity.\n"
                "- Watch out for joint programs! If a hyphen separates TWO truly distinct institutions, split them.\n"
                "- Do NOT split single names with hyphens (e.g., 'Paris-Descartes').\n"
                "- Strip out ONLY standard degree titles (e.g., 'MSc', 'MBA', 'BBA', 'Ph.D.')."
            )
        if column_name in ("Country", "Continent"):
            return (
                f"- MUST translate the {column_name} name to its standard ENGLISH equivalent (e.g., 'Espagne' -> 'Spain', 'Milano' -> 'Milan').\n"
                f"- If the input contains multiple distinct {column_name}s, MUST split them into a comma-separated list.\n"
                f"- If it is a single {column_name}, return exactly ONE official English name."
            )
        return (
            "- If the input contains multiple distinct entities, MUST split them into a comma-separated list.\n"
            "- If it is a single entity, return exactly ONE official name. No aliases."
        )

    @classmethod
    def _build_payload(cls, word, candidates, column_name):
        """Build the chat-completions payload; the caller adds ``model``."""
        specific_rules = cls._column_rules(column_name)
        system_msg = (
            "You are a data normalization expert. Your task is to map raw input to the MOST ACCURATE "
            "official name from a provided list of candidates or your own world knowledge. "
            "Respond ONLY with the final string. No conversational text, no quotes."
        )
        user_msg = (
            f"Column: {column_name}\n"
            f"Raw Input: '{word}'\n"
            f"Database Candidates: {candidates}\n\n"
            f"Rules:\n"
            f"1. Choose the best match from the Database Candidates if possible.\n"
            f"2. If an entity from the Raw Input does NOT exist in the Candidates list, DO NOT map it to a different candidate just because they share a city or country. Instead, return its cleaned original name exactly as written.\n"
            f"{specific_rules}\n"
            f"3. If gibberish, return 'UNKNOWN'."
        )
        return {
            "messages": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            "temperature": 0.0,
            "max_tokens": 50,
        }

    @staticmethod
    def _interpret_answer(word, content):
        """Turn the model's reply text into a ``(value, status, confidence)`` tuple.

        Surrounding whitespace and quote characters are removed. A missing,
        empty, or ``UNKNOWN`` reply (any letter case) keeps the raw ``word``
        and is reported as ``LLM_Failed`` so a blank reply can never overwrite
        the original value with an empty string.
        """
        # ``content`` may be None; treat that the same as an empty reply.
        answer = (content or "").strip().strip('"\'').strip()
        if not answer or answer.upper() == "UNKNOWN":
            return (word, "LLM_Failed", "LOW")
        return (answer, "LLM_Judged", "MEDIUM")

    def ask_judge(self, word, candidates, column_name):
        """Ask Groq to normalize one raw value against likely candidates.

        Args:
            word: Raw value to normalize; returned unchanged on any failure.
            candidates: Candidate names interpolated into the prompt.
            column_name: Column the value belongs to; selects the prompt rules.

        Returns:
            A ``(value, status, confidence)`` tuple. ``status`` is one of
            ``LLM_Judged``, ``LLM_Failed``, ``No_API_Key``,
            ``API_Error_<http status>``, ``API_Error_Timeout``,
            ``API_Error_All_Models_Dead`` or ``API_Error_Unhandled``.
            ``confidence`` is ``MEDIUM`` for ``LLM_Judged`` and ``LOW``
            otherwise.
        """
        if self.current_model_index >= len(self.available_models):
            return (word, "API_Error_All_Models_Dead", "LOW")
        if not self.api_key:
            return (word, "No_API_Key", "LOW")

        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        payload = self._build_payload(word, candidates, column_name)

        @retry(
            retry=retry_if_exception_type(RateLimitException),
            wait=wait_exponential(multiplier=2, min=2, max=30),
            stop=stop_after_attempt(5),
            reraise=True
        )
        def fire_request():
            """Fire one request; tenacity retries only explicit rate-limit errors."""
            res = requests.post(
                self.API_URL,
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )
            if res.status_code == 429:
                raise RateLimitException("Model Limits Exhausted")
            return res

        while self.current_model_index < len(self.available_models):
            active_model = self.available_models[self.current_model_index]
            # fire_request closes over this dict, so updating the key here
            # redirects the next attempt to the active model.
            payload["model"] = active_model
            if active_model != self.last_printed_model:
                tqdm.write(f"\n[INFO] Routing requests to model: {active_model}")
                self.last_printed_model = active_model
            try:
                time.sleep(self.THROTTLE_SECONDS)
                response = fire_request()
                if response.status_code != 200:
                    tqdm.write(f" [!] Groq API Error: {response.status_code} for '{word}'")
                    return (word, f"API_Error_{response.status_code}", "LOW")
                content = response.json()["choices"][0]["message"]["content"]
                return self._interpret_answer(word, content)
            except RateLimitException:
                tqdm.write(f" [!] Limits exhausted for {active_model}!")
                # Move to the next configured model and keep processing.
                self.current_model_index += 1
                if self.current_model_index < len(self.available_models):
                    next_model = self.available_models[self.current_model_index]
                    tqdm.write(f" [!] Switching to fallback model: {next_model}...")
                else:
                    tqdm.write(f" [!] FATAL: ALL {len(self.available_models)} MODELS DEAD. Skipping remainder of dataset.")
                    return (word, "API_Error_All_Models_Dead", "LOW")
            except Exception as e:
                # Best-effort boundary: any other failure (network error,
                # malformed response body, ...) is reported under the same
                # status string, which is kept unchanged for existing callers.
                tqdm.write(f" [!] Request failed for '{word}' on {active_model}: {e}")
                return (word, "API_Error_Timeout", "LOW")
        return (word, "API_Error_Unhandled", "LOW")