File size: 5,598 Bytes
1195c6b
fcb28e8
1195c6b
 
 
07ccb16
1195c6b
959adae
1195c6b
cc80e6a
959adae
 
 
 
 
 
cc80e6a
 
 
96f6a80
 
959adae
 
cc80e6a
959adae
 
cc80e6a
 
 
 
 
96f6a80
959adae
cc80e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
959adae
cc80e6a
959adae
 
 
 
1195c6b
 
 
 
 
 
 
 
07ccb16
1195c6b
 
 
07ccb16
 
1195c6b
 
07ccb16
1195c6b
 
5f5687e
1195c6b
 
 
 
07ccb16
 
 
 
 
1195c6b
07ccb16
 
 
 
 
 
 
 
 
 
 
1195c6b
07ccb16
 
1195c6b
07ccb16
 
 
 
1195c6b
07ccb16
443282d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07ccb16
 
 
 
 
 
 
 
 
 
 
 
 
 
959adae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import asyncio
import threading
import json
from threading import Lock
from huggingface_hub import HfApi, hf_hub_download
from google import genai
from groq import Groq

from groq import AsyncGroq

# LLM provider managers (Groq, Gemini) and the Hugging Face dataset manager follow.

class _GroqResponse:
    """Minimal response wrapper exposing the completion text as ``.text``.

    Defined once at module level (rather than inside ``generate_fast``) so
    every response shares the same type. It exists so that consumers reading
    ``response.text`` (e.g. AgentInterpretation) keep working with Groq output.
    """

    def __init__(self, text):
        self.text = text


class GroqManager:
    """Async text generation through the Groq chat-completions API.

    Attributes:
        api_key: Groq API key, or a falsy value when none is configured.
        client_available: ``True`` when an API key was supplied.
        model_flash: model identifier sent with every request.
        last_used_model: human-readable label of the last model used,
            ``"Idle"`` until the first request.
    """

    def __init__(self, api_key):
        """Store the key and model settings; no Groq client is created here."""
        self.api_key = api_key
        # The client itself is created per call inside generate_fast, so only
        # record whether a key is available.
        self.client_available = bool(self.api_key)

        # Using Llama 3.3 70B for better translation quality
        self.model_flash = "llama-3.3-70b-versatile"
        self.last_used_model = "Idle"

        if self.client_available:
            print(f"⚡ Groq Manager Connected: Standardized on {self.model_flash}")

    async def generate_fast(self, prompt):
        """Send ``prompt`` as a single user message and return the reply.

        Args:
            prompt: text placed in the ``content`` of one ``user`` message.

        Returns:
            An object whose ``.text`` attribute holds the content of the
            first choice of the completion.

        Raises:
            Exception: if no API key was supplied at construction time.
        """
        if not self.client_available:
            raise Exception("Groq API Key not found.")

        self.last_used_model = "Llama 3.3 70B (Groq)"

        # A fresh client is opened for each call and closed when the
        # ``async with`` block exits, so nothing outlives the coroutine.
        async with AsyncGroq(api_key=self.api_key) as client:
            chat_completion = await client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                model=self.model_flash,
            )

        return _GroqResponse(chat_completion.choices[0].message.content)

    async def generate_smart(self, prompt):
        """Return the same result as ``generate_fast``.

        Both entry points are routed to the fast model for now to bypass
        rate limits entirely.
        """
        return await self.generate_fast(prompt)

    def get_status_string(self):
        """Return a short status label built from ``last_used_model``."""
        return f"⚡ {self.last_used_model}"
        
class GeminiManager:
    """Text generation through a ``genai.Client`` pinned to one Gemini model.

    Attributes:
        api_key: the key passed at construction (may be falsy).
        client: a ``genai.Client`` when a key was given, otherwise ``None``.
        model_flash: model identifier sent with every request.
        last_used_model: label of the last model used, ``"Idle"`` initially.
    """

    def __init__(self, api_key):
        """Create the client only when an API key is present."""
        self.api_key = api_key
        self.client = None
        if api_key:
            self.client = genai.Client(api_key=api_key)
        self.model_flash = "gemini-2.0-flash"
        self.last_used_model = "Idle"

        if self.client:
            print(f"🧠 Gemini Manager Connected: Standardized on {self.model_flash}")

    def generate_fast(self, prompt):
        """Generate content for ``prompt`` with ``model_flash``.

        Returns whatever ``client.models.generate_content`` returns.
        Raises ``Exception`` when no client was created (no API key).
        """
        active_client = self.client
        if not active_client:
            raise Exception("Google API Key not found.")

        # Only mark the model as used once we know a request will be made.
        self.last_used_model = "Gemini 2.0 Flash"
        response = active_client.models.generate_content(
            model=self.model_flash,
            contents=prompt,
        )
        return response

    def generate_smart(self, prompt):
        """Delegate to ``generate_fast``; both entry points use one model."""
        result = self.generate_fast(prompt)
        return result

    def get_status_string(self):
        """Return a short status label built from ``last_used_model``."""
        return "⚡ {}".format(self.last_used_model)

class HFManager:
    """Sync datasets, profiles and audio samples with a Hugging Face dataset repo.

    ``config`` is read for ``HF_TOKEN``, ``HF_REPO_ID``, ``PROFILES_DIR`` and
    ``DATASET_DIR``.
    """

    # File extensions that pull_datasets downloads; other repo files are skipped.
    _SYNCED_SUFFIXES = (".csv", ".json", ".jsonl")

    def __init__(self, config):
        """Keep ``config`` and build an ``HfApi`` client from its token."""
        self.config = config
        self.api = HfApi(token=self.config.HF_TOKEN)
        # NOTE(review): no method of this class acquires the lock; presumably
        # callers use it to serialize their own file access - confirm.
        self.lock = Lock()

    def pull_datasets(self):
        """Download every CSV/JSON/JSONL file of the dataset repo.

        ``.json`` files go to ``PROFILES_DIR``; ``.csv`` and ``.jsonl`` files
        go to ``DATASET_DIR``. Any failure is logged and the default profile
        is seeded instead, so this method does not raise on pull errors.
        """
        print("⬇️ Pulling datasets & profiles...")
        try:
            files = self.api.list_repo_files(repo_id=self.config.HF_REPO_ID, repo_type="dataset")
            for f in files:
                if not f.endswith(self._SYNCED_SUFFIXES):
                    continue
                target_dir = self.config.PROFILES_DIR if f.endswith(".json") else self.config.DATASET_DIR
                hf_hub_download(
                    repo_id=self.config.HF_REPO_ID,
                    filename=f,
                    repo_type="dataset",
                    local_dir=target_dir,
                    token=self.config.HF_TOKEN,
                )
        except Exception as e:
            print(f"❌ HF Pull Error: {e}")
            self.seed_initial_data()

    def seed_initial_data(self):
        """Write the default "NSL Lab Trainer" profile if it does not exist.

        An existing profile file is never overwritten.
        """
        profiles_dir = self.config.PROFILES_DIR
        # This runs as the fallback after a failed pull, when nothing may have
        # created the directory yet; make sure it exists before writing.
        if profiles_dir:
            os.makedirs(profiles_dir, exist_ok=True)
        p_path = os.path.join(profiles_dir, "NSL Lab Trainer.json")
        if not os.path.exists(p_path):
            with open(p_path, "w", encoding="utf-8") as f:
                json.dump({"lab_name": "NSL", "jargon": {}, "pragmatic_rules": []}, f)

    def push_update(self, filepath, commit_msg="Update"):
        """Upload ``filepath`` to the repo root from a background daemon thread.

        Returns immediately; success or failure is only reported via ``print``.

        Args:
            filepath: local file to upload; stored under its base name.
            commit_msg: commit message for the upload.
        """
        def _upload_task():
            # A dedicated event loop is installed for this worker thread and
            # closed when the upload finishes.
            # NOTE(review): upload_file is called synchronously here, so it is
            # unclear what requires the loop - confirm before removing.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                filename = os.path.basename(filepath)
                self.api.upload_file(
                    path_or_fileobj=filepath,
                    path_in_repo=filename,
                    repo_id=self.config.HF_REPO_ID,
                    repo_type="dataset",
                    commit_message=commit_msg,
                )
                print(f"✅ Background Sync Complete: {filename}")
            except Exception as e:
                print(f"❌ HF Push Error: {e}")
            finally:
                loop.close()

        threading.Thread(target=_upload_task, daemon=True).start()

    def upload_audio_sample(self, audio_path, dialect):
        """Upload an audio file to ``audio/<dialect>/<basename>`` in the repo.

        Args:
            audio_path: local path of the audio file.
            dialect: dialect label; surrounding whitespace is stripped.

        Returns:
            The path inside the repo on success, ``None`` if the upload failed.
        """
        clean_dialect = dialect.strip()
        filename = os.path.basename(audio_path)
        hf_path = f"audio/{clean_dialect}/{filename}"
        try:
            self.api.upload_file(
                path_or_fileobj=audio_path,
                path_in_repo=hf_path,
                repo_id=self.config.HF_REPO_ID,
                repo_type="dataset",
                commit_message=f"Add audio sample for {clean_dialect}",
            )
        except Exception as e:
            # Still best-effort (callers get None), but no longer silent.
            print(f"❌ HF Audio Upload Error: {e}")
            return None
        return hf_path