Spaces:
Running
Running
File size: 6,610 Bytes
dc06d4c c6a3f44 dc06d4c c6a3f44 dc06d4c c6a3f44 dc06d4c c6a3f44 dc06d4c c6a3f44 dc06d4c c6a3f44 dc06d4c c6a3f44 dc06d4c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 | import requests
import time
from tqdm import tqdm
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
class RateLimitException(Exception):
    """Signal that the active Groq model answered with HTTP 429 (rate limited).

    Raising this lets the caller retry the request and, once retries are
    exhausted, rotate to another configured model.
    """
class GroqRouter:
    """Small Groq client that rotates through configured fallback models.

    Each ``ask_judge`` call sends one chat-completion request to Groq's
    OpenAI-compatible endpoint using the currently active model. When a model
    keeps answering HTTP 429 (after the tenacity retries run out), the router
    advances to the next entry of ``available_models``. The index never moves
    back, so once every model is exhausted all later calls short-circuit.
    """

    def __init__(self, api_key, available_models):
        """Store the API key and the ordered list of models to try.

        Args:
            api_key: Groq API key. A falsy value makes ``ask_judge`` return a
                ``"No_API_Key"`` result without touching the network.
            available_models: Ordered sequence of model names. Index 0 is
                tried first; later entries are fallbacks.
        """
        self.api_key = api_key
        self.available_models = available_models
        # Index of the model currently in use; it is only ever incremented.
        self.current_model_index = 0
        # Last model announced via tqdm.write, so the routing message is
        # printed once per model switch rather than once per request.
        self.last_printed_model = None

    @staticmethod
    def _build_messages(word, candidates, column_name):
        """Build the system + user chat messages for one normalization request."""
        # Column-specific prompt rules prevent the model from over-splitting
        # institutions while still translating geography values to English.
        if column_name in ["Institution", "Degree"]:
            specific_rules = (
                "- Split distinct separate schools or global alliances with a comma (e.g., 'Harvard & MIT' -> 'Harvard University, MIT').\n"
                "- Treat alliances like CEMS as distinct institutions! NEVER delete them.\n"
                "- Do NOT split a Business School and its Parent University. They are ONE entity.\n"
                "- Do NOT split a school and its acronym/alias. Keep them as one entity.\n"
                "- Watch out for joint programs! If a hyphen separates TWO truly distinct institutions, split them.\n"
                "- Do NOT split single names with hyphens (e.g., 'Paris-Descartes').\n"
                "- Strip out ONLY standard degree titles (e.g., 'MSc', 'MBA', 'BBA', 'Ph.D.')."
            )
        elif column_name in ["Country", "Continent"]:
            specific_rules = (
                f"- MUST translate the {column_name} name to its standard ENGLISH equivalent (e.g., 'Espagne' -> 'Spain', 'Milano' -> 'Milan').\n"
                f"- If the input contains multiple distinct {column_name}s, MUST split them into a comma-separated list.\n"
                f"- If it is a single {column_name}, return exactly ONE official English name."
            )
        else:
            specific_rules = (
                "- If the input contains multiple distinct entities, MUST split them into a comma-separated list.\n"
                "- If it is a single entity, return exactly ONE official name. No aliases."
            )
        system_msg = (
            "You are a data normalization expert. Your task is to map raw input to the MOST ACCURATE "
            "official name from a provided list of candidates or your own world knowledge. "
            "Respond ONLY with the final string. No conversational text, no quotes."
        )
        user_msg = (
            f"Column: {column_name}\n"
            f"Raw Input: '{word}'\n"
            f"Database Candidates: {candidates}\n\n"
            "Rules:\n"
            "1. Choose the best match from the Database Candidates if possible.\n"
            "2. If an entity from the Raw Input does NOT exist in the Candidates list, DO NOT map it to a different candidate just because they share a city or country. Instead, return its cleaned original name exactly as written.\n"
            f"{specific_rules}\n"
            "3. If gibberish, return 'UNKNOWN'."
        )
        return [{"role": "system", "content": system_msg}, {"role": "user", "content": user_msg}]

    @staticmethod
    def _interpret_answer(word, content):
        """Convert the model's raw reply into a ``(value, source, confidence)`` tuple.

        An ``UNKNOWN`` reply (any letter case, optional trailing period) or an
        empty/missing reply counts as a failed judgement. In that case the
        original ``word`` is kept rather than replaced by a sentinel or a blank.
        """
        # `content` may be None in OpenAI-compatible responses; treat as empty.
        answer = (content or "").strip().strip('"\'').strip()
        if not answer or answer.rstrip(".").upper() == "UNKNOWN":
            return (word, "LLM_Failed", "LOW")
        return (answer, "LLM_Judged", "MEDIUM")

    def ask_judge(self, word, candidates, column_name):
        """Ask Groq to normalize one raw value against likely candidates.

        Args:
            word: Raw cell value to normalize.
            candidates: Candidate official names. They are interpolated
                verbatim into the prompt (presumably a list of strings; confirm
                against the caller).
            column_name: Column the value comes from. It selects the rule set:
                "Institution"/"Degree", "Country"/"Continent", or generic.

        Returns:
            A ``(value, source, confidence)`` tuple.

            - Success: the model's answer, ``"LLM_Judged"``, ``"MEDIUM"``.
            - Any failure: the original ``word`` with confidence ``"LOW"`` and
              a ``source`` label of ``"LLM_Failed"``, ``"No_API_Key"``,
              ``"API_Error_<status>"``, ``"API_Error_Timeout"`` or
              ``"API_Error_All_Models_Dead"``.
        """
        if self.current_model_index >= len(self.available_models):
            return (word, "API_Error_All_Models_Dead", "LOW")
        if not self.api_key:
            return (word, "No_API_Key", "LOW")

        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        payload = {
            "messages": self._build_messages(word, candidates, column_name),
            "temperature": 0.0,
            "max_tokens": 50,
        }

        @retry(
            retry=retry_if_exception_type(RateLimitException),
            wait=wait_exponential(multiplier=2, min=2, max=30),
            stop=stop_after_attempt(5),
            reraise=True,
        )
        def fire_request():
            """Fire one request; tenacity retries only explicit rate-limit errors."""
            res = requests.post("https://api.groq.com/openai/v1/chat/completions", headers=headers, json=payload, timeout=30)
            if res.status_code == 429:
                # With reraise=True the final RateLimitException itself (not a
                # tenacity RetryError) escapes, so the loop below can rotate.
                raise RateLimitException("Model Limits Exhausted")
            return res

        while self.current_model_index < len(self.available_models):
            active_model = self.available_models[self.current_model_index]
            # fire_request closes over `payload`, so this selects the model.
            payload["model"] = active_model
            if active_model != self.last_printed_model:
                tqdm.write(f"\n[INFO] Routing requests to model: {active_model}")
                self.last_printed_model = active_model
            try:
                # Light throttling reduces avoidable rate-limit pressure.
                time.sleep(0.3)
                response = fire_request()
                if response.status_code != 200:
                    tqdm.write(f" [!] Groq API Error: {response.status_code} for '{word}'")
                    return (word, f"API_Error_{response.status_code}", "LOW")
                content = response.json()["choices"][0]["message"]["content"]
                return self._interpret_answer(word, content)
            except RateLimitException:
                tqdm.write(f" [!] Limits exhausted for {active_model}!")
                # Move to the next configured model and keep processing.
                self.current_model_index += 1
                if self.current_model_index < len(self.available_models):
                    next_model = self.available_models[self.current_model_index]
                    tqdm.write(f" [!] Switching to fallback model: {next_model}...")
                else:
                    tqdm.write(f" [!] FATAL: ALL {len(self.available_models)} MODELS DEAD. Skipping remainder of dataset.")
                    return (word, "API_Error_All_Models_Dead", "LOW")
            except Exception as e:
                # Best-effort: any other failure (network error, timeout,
                # malformed JSON body) skips this value instead of aborting
                # the batch. NOTE: all of these are labelled
                # "API_Error_Timeout"; the label is kept for backward
                # compatibility with downstream consumers.
                tqdm.write(f" [!] Request failed for '{word}' on {active_model}: {e}")
                return (word, "API_Error_Timeout", "LOW")
        # Defensive fallback: every loop iteration either returns or advances
        # to a model that is checked on the next iteration.
        return (word, "API_Error_Unhandled", "LOW")
|