File size: 6,610 Bytes
dc06d4c
 
 
 
 
c6a3f44
 
 
dc06d4c
 
c6a3f44
 
dc06d4c
 
 
 
 
 
 
c6a3f44
dc06d4c
 
 
 
 
 
 
 
c6a3f44
 
dc06d4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6a3f44
dc06d4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6a3f44
dc06d4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6a3f44
dc06d4c
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import requests
import time
from tqdm import tqdm
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type

class RateLimitException(Exception):
    """Signal that a Groq model answered with HTTP 429 (rate limit reached).

    The router's request helper raises this so tenacity can retry the call;
    once retries are exhausted the router catches it and rotates to the next
    configured fallback model.
    """

class GroqRouter:
    """Small Groq client that rotates through configured fallback models.

    Models are tried in the order given by ``available_models``. When a model
    keeps answering HTTP 429 after the tenacity retries, the router advances
    ``current_model_index`` for good, so later calls start on the next model.
    Once every model has been exhausted, all further calls short-circuit.

    :meth:`ask_judge` reports failures through its return tuple instead of
    raising, so one failed lookup does not abort a long normalization run.
    """

    # Groq's OpenAI-compatible chat completions endpoint.
    API_URL = "https://api.groq.com/openai/v1/chat/completions"

    def __init__(self, api_key, available_models):
        """Store the API key and the ordered list of fallback model ids.

        Args:
            api_key: Groq API key; a falsy value makes every call return
                status ``"No_API_Key"``.
            available_models: Ordered, indexable sequence of model ids.
        """
        self.api_key = api_key
        self.available_models = available_models
        # Index of the model currently in use; it only ever moves forward.
        self.current_model_index = 0
        # Last model announced via tqdm.write, so the banner prints once per model.
        self.last_printed_model = None

    @staticmethod
    def _column_rules(column_name):
        """Return the column-specific rule lines inserted into the user prompt."""
        # Column-specific prompt rules prevent the model from over-splitting
        # institutions while still translating geography values to English.
        if column_name in ("Institution", "Degree"):
            return (
                "- Split distinct separate schools or global alliances with a comma (e.g., 'Harvard & MIT' -> 'Harvard University, MIT').\n"
                "- Treat alliances like CEMS as distinct institutions! NEVER delete them.\n"
                "- Do NOT split a Business School and its Parent University. They are ONE entity.\n"
                "- Do NOT split a school and its acronym/alias. Keep them as one entity.\n"
                "- Watch out for joint programs! If a hyphen separates TWO truly distinct institutions, split them.\n"
                "- Do NOT split single names with hyphens (e.g., 'Paris-Descartes').\n"
                "- Strip out ONLY standard degree titles (e.g., 'MSc', 'MBA', 'BBA', 'Ph.D.')."
            )
        if column_name in ("Country", "Continent"):
            return (
                f"- MUST translate the {column_name} name to its standard ENGLISH equivalent (e.g., 'Espagne' -> 'Spain', 'Milano' -> 'Milan').\n"
                f"- If the input contains multiple distinct {column_name}s, MUST split them into a comma-separated list.\n"
                f"- If it is a single {column_name}, return exactly ONE official English name."
            )
        return (
            "- If the input contains multiple distinct entities, MUST split them into a comma-separated list.\n"
            "- If it is a single entity, return exactly ONE official name. No aliases."
        )

    @classmethod
    def _build_payload(cls, word, candidates, column_name):
        """Build the chat-completions payload for one value (``model`` is set by the caller)."""
        specific_rules = cls._column_rules(column_name)

        system_msg = (
            "You are a data normalization expert. Your task is to map raw input to the MOST ACCURATE "
            "official name from a provided list of candidates or your own world knowledge. "
            "Respond ONLY with the final string. No conversational text, no quotes."
        )

        user_msg = (
            f"Column: {column_name}\n"
            f"Raw Input: '{word}'\n"
            f"Database Candidates: {candidates}\n\n"
            "Rules:\n"
            "1. Choose the best match from the Database Candidates if possible.\n"
            "2. If an entity from the Raw Input does NOT exist in the Candidates list, DO NOT map it to a different candidate just because they share a city or country. Instead, return its cleaned original name exactly as written.\n"
            f"{specific_rules}\n"
            "3. If gibberish, return 'UNKNOWN'."
        )

        return {
            "messages": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            "temperature": 0.0,
            "max_tokens": 50,
        }

    @staticmethod
    def _interpret_answer(word, content):
        """Convert the model's raw reply into a ``(value, status, confidence)`` tuple.

        An empty or missing reply, or ``UNKNOWN`` in any casing (with an optional
        trailing period), counts as a failed judgement. In that case the original
        ``word`` is kept instead of being replaced by an empty or sentinel string.
        """
        answer = (content or "").strip().strip('"\'')
        if not answer or answer.upper().rstrip(".") == "UNKNOWN":
            return (word, "LLM_Failed", "LOW")
        return (answer, "LLM_Judged", "MEDIUM")

    def ask_judge(self, word, candidates, column_name):
        """Ask Groq to normalize one raw value against likely candidates.

        Args:
            word: Raw cell value to normalize.
            candidates: Likely matches; interpolated verbatim into the prompt.
            column_name: Column label; selects the column-specific prompt rules.

        Returns:
            A ``(value, status, confidence)`` tuple. On success, ``value`` is the
            model's answer, ``status`` is ``"LLM_Judged"`` and ``confidence`` is
            ``"MEDIUM"``. On failure, ``value`` is the original ``word``,
            ``confidence`` is ``"LOW"`` and ``status`` is one of:

            - ``"LLM_Failed"``
            - ``"No_API_Key"``
            - ``"API_Error_<http status>"``
            - ``"API_Error_Timeout"`` (network/request failure)
            - ``"API_Error_Bad_Response"`` (unparseable 200 body)
            - ``"API_Error_All_Models_Dead"``

        Side effects:
            May advance ``current_model_index`` permanently when a model keeps
            returning HTTP 429.
        """
        if self.current_model_index >= len(self.available_models):
            return (word, "API_Error_All_Models_Dead", "LOW")

        if not self.api_key:
            return (word, "No_API_Key", "LOW")

        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        payload = self._build_payload(word, candidates, column_name)

        @retry(
            retry=retry_if_exception_type(RateLimitException),
            wait=wait_exponential(multiplier=2, min=2, max=30),
            stop=stop_after_attempt(5),
            reraise=True,
        )
        def fire_request():
            """Fire one request; tenacity retries only explicit rate-limit errors."""
            res = requests.post(self.API_URL, headers=headers, json=payload, timeout=30)
            if res.status_code == 429:
                raise RateLimitException("Model Limits Exhausted")
            return res

        while self.current_model_index < len(self.available_models):
            active_model = self.available_models[self.current_model_index]
            # fire_request reads this same dict, so the model swap is picked up.
            payload["model"] = active_model

            if active_model != self.last_printed_model:
                tqdm.write(f"\n[INFO] Routing requests to model: {active_model}")
                self.last_printed_model = active_model

            try:
                # Light throttling reduces avoidable rate-limit pressure.
                time.sleep(0.3)
                response = fire_request()
            except RateLimitException:
                tqdm.write(f"  [!] Limits exhausted for {active_model}!")

                # Move to the next configured model and keep processing.
                self.current_model_index += 1
                if self.current_model_index < len(self.available_models):
                    next_model = self.available_models[self.current_model_index]
                    tqdm.write(f"  [!] Switching to fallback model: {next_model}...")
                    continue
                tqdm.write(f"  [!] FATAL: ALL {len(self.available_models)} MODELS DEAD. Skipping remainder of dataset.")
                return (word, "API_Error_All_Models_Dead", "LOW")
            except requests.RequestException as e:
                tqdm.write(f"  [!] Request failed for '{word}' on {active_model}: {e}")
                return (word, "API_Error_Timeout", "LOW")

            if response.status_code != 200:
                tqdm.write(f"  [!] Groq API Error: {response.status_code} for '{word}'")
                return (word, f"API_Error_{response.status_code}", "LOW")

            # Parse separately so a malformed body is not misreported as a timeout.
            try:
                content = response.json()["choices"][0]["message"]["content"]
            except (ValueError, KeyError, IndexError, TypeError) as e:
                tqdm.write(f"  [!] Malformed response for '{word}' on {active_model}: {e}")
                return (word, "API_Error_Bad_Response", "LOW")
            return self._interpret_answer(word, content)

        # Defensive: every path inside the loop returns or continues.
        return (word, "API_Error_Unhandled", "LOW")