Spaces:
Running
Running
| import requests | |
| import time | |
| from tqdm import tqdm | |
| from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type | |
class RateLimitException(Exception):
    """Signal that the active Groq model answered with HTTP 429.

    Raised by the request helper when the API reports a rate limit, and
    caught by the router so it can advance to the next configured model.
    """
class GroqRouter:
    """Small Groq chat-completions client that rotates through fallback models.

    The router keeps an index into ``available_models``. Whenever the API
    answers HTTP 429 for the active model, the index is advanced and the same
    request is re-sent to the next model. The index is never reset, so a model
    that hit its limit once is not used again by this instance.

    Attributes:
        api_key: Bearer token sent in the ``Authorization`` header.
        available_models: Ordered model names; index 0 is tried first.
        current_model_index: Position of the model currently in use.
        last_printed_model: Last model announced via ``tqdm.write`` (used to
            avoid printing the routing banner for every request).
    """

    # Endpoint and pacing knobs (previously hard-coded inside ask_judge).
    API_URL = "https://api.groq.com/openai/v1/chat/completions"
    REQUEST_TIMEOUT = 30  # seconds, forwarded to requests.post
    THROTTLE_SECONDS = 0.3  # pause before every request

    def __init__(self, api_key, available_models):
        """Store the credentials and the ordered list of fallback models.

        Args:
            api_key: Groq API key; a falsy value makes ``ask_judge`` return
                the ``"No_API_Key"`` status without any network call.
            available_models: Sequence of model names, or ``None`` for none.
        """
        self.api_key = api_key
        # Tolerate None so ask_judge reports "all models dead" instead of
        # crashing on len(None).
        self.available_models = available_models if available_models is not None else []
        self.current_model_index = 0
        self.last_printed_model = None

    @staticmethod
    def _column_rules(column_name):
        """Return the column-specific rule lines inserted into the user prompt."""
        # Column-specific prompt rules prevent the model from over-splitting
        # institutions while still translating geography values to English.
        if column_name in ("Institution", "Degree"):
            return (
                "- Split distinct separate schools or global alliances with a comma (e.g., 'Harvard & MIT' -> 'Harvard University, MIT').\n"
                "- Treat alliances like CEMS as distinct institutions! NEVER delete them.\n"
                "- Do NOT split a Business School and its Parent University. They are ONE entity.\n"
                "- Do NOT split a school and its acronym/alias. Keep them as one entity.\n"
                "- Watch out for joint programs! If a hyphen separates TWO truly distinct institutions, split them.\n"
                "- Do NOT split single names with hyphens (e.g., 'Paris-Descartes').\n"
                "- Strip out ONLY standard degree titles (e.g., 'MSc', 'MBA', 'BBA', 'Ph.D.')."
            )
        if column_name in ("Country", "Continent"):
            return (
                f"- MUST translate the {column_name} name to its standard ENGLISH equivalent (e.g., 'Espagne' -> 'Spain', 'Milano' -> 'Milan').\n"
                f"- If the input contains multiple distinct {column_name}s, MUST split them into a comma-separated list.\n"
                f"- If it is a single {column_name}, return exactly ONE official English name."
            )
        return (
            "- If the input contains multiple distinct entities, MUST split them into a comma-separated list.\n"
            "- If it is a single entity, return exactly ONE official name. No aliases."
        )

    @classmethod
    def _build_messages(cls, word, candidates, column_name):
        """Build the system + user chat messages for one normalization request."""
        system_msg = (
            "You are a data normalization expert. Your task is to map raw input to the MOST ACCURATE "
            "official name from a provided list of candidates or your own world knowledge. "
            "Respond ONLY with the final string. No conversational text, no quotes."
        )
        user_msg = (
            f"Column: {column_name}\n"
            f"Raw Input: '{word}'\n"
            f"Database Candidates: {candidates}\n\n"
            "Rules:\n"
            "1. Choose the best match from the Database Candidates if possible.\n"
            "2. If an entity from the Raw Input does NOT exist in the Candidates list, DO NOT map it to a different candidate just because they share a city or country. Instead, return its cleaned original name exactly as written.\n"
            f"{cls._column_rules(column_name)}\n"
            "3. If gibberish, return 'UNKNOWN'."
        )
        return [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ]

    @staticmethod
    def _extract_answer(body):
        """Return the cleaned answer text from a decoded chat-completions body.

        Surrounding whitespace and quote characters are stripped. An empty
        string is returned when the message content is not a string (for
        example ``None``).

        Raises:
            KeyError, IndexError, TypeError: If ``body`` does not have the
                ``choices[0].message.content`` structure.
        """
        content = body["choices"][0]["message"]["content"]
        if not isinstance(content, str):
            return ""
        return content.strip().strip('"\'').strip()

    def _fire_request(self, headers, payload):
        """Send one POST to the Groq endpoint.

        No automatic retry is applied here: a 429 is surfaced immediately as
        ``RateLimitException`` so the caller can rotate to the next model.

        Raises:
            RateLimitException: If the API answers with HTTP 429.
        """
        res = requests.post(
            self.API_URL,
            headers=headers,
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        if res.status_code == 429:
            raise RateLimitException("Model Limits Exhausted")
        return res

    def ask_judge(self, word, candidates, column_name):
        """Ask Groq to normalize one raw value against likely candidates.

        Args:
            word: Raw value to normalize.
            candidates: Candidate names; interpolated into the prompt as-is.
            column_name: Column the value belongs to; selects the prompt rules.

        Returns:
            A ``(value, status, confidence)`` tuple:

            * ``(answer, "LLM_Judged", "MEDIUM")`` on a usable model answer.
            * ``(word, "LLM_Failed", "LOW")`` if the model answered
              ``UNKNOWN`` or returned no text.
            * ``(word, "No_API_Key", "LOW")`` if no API key is configured.
            * ``(word, "API_Error_All_Models_Dead", "LOW")`` once every
              configured model has been rate-limited.
            * ``(word, "API_Error_<status>", "LOW")`` for a non-200, non-429
              HTTP status.
            * ``(word, "API_Error_Timeout", "LOW")`` if sending the request
              raised (timeout, connection error, ...).
            * ``(word, "API_Error_Bad_Response", "LOW")`` if a 200 response
              body is not valid JSON or lacks the expected structure.
        """
        if self.current_model_index >= len(self.available_models):
            return (word, "API_Error_All_Models_Dead", "LOW")
        if not self.api_key:
            return (word, "No_API_Key", "LOW")

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "messages": self._build_messages(word, candidates, column_name),
            "temperature": 0.0,
            "max_tokens": 50,
        }

        while self.current_model_index < len(self.available_models):
            active_model = self.available_models[self.current_model_index]
            payload["model"] = active_model

            if active_model != self.last_printed_model:
                tqdm.write(f"\n[INFO] Routing requests to model: {active_model}")
                self.last_printed_model = active_model

            # Light throttling reduces avoidable rate-limit pressure.
            time.sleep(self.THROTTLE_SECONDS)

            try:
                response = self._fire_request(headers, payload)
            except RateLimitException:
                tqdm.write(f" [!] Limits exhausted for {active_model}!")
                # Move to the next configured model and keep processing.
                self.current_model_index += 1
                if self.current_model_index >= len(self.available_models):
                    tqdm.write(f" [!] FATAL: ALL {len(self.available_models)} MODELS DEAD. Skipping remainder of dataset.")
                    return (word, "API_Error_All_Models_Dead", "LOW")
                next_model = self.available_models[self.current_model_index]
                tqdm.write(f" [!] Switching to fallback model: {next_model}...")
                continue
            except Exception as e:
                # Best-effort: any transport failure degrades to a LOW result
                # for this value instead of aborting the whole run.
                tqdm.write(f" [!] Request failed for '{word}' on {active_model}: {e}")
                return (word, "API_Error_Timeout", "LOW")

            if response.status_code != 200:
                tqdm.write(f" [!] Groq API Error: {response.status_code} for '{word}'")
                return (word, f"API_Error_{response.status_code}", "LOW")

            # Parse separately from the request so a malformed body is not
            # misreported as a timeout.
            try:
                answer = self._extract_answer(response.json())
            except (ValueError, KeyError, IndexError, TypeError) as e:
                tqdm.write(f" [!] Malformed Groq response for '{word}' on {active_model}: {e!r}")
                return (word, "API_Error_Bad_Response", "LOW")

            # An empty answer must not overwrite the raw value with "".
            if not answer or answer == "UNKNOWN":
                return (word, "LLM_Failed", "LOW")
            return (answer, "LLM_Judged", "MEDIUM")

        # Defensive fallback: only reachable if the loop condition becomes
        # false without any branch above having returned.
        return (word, "API_Error_Unhandled", "LOW")