Spaces:
Running
Running
| import os | |
| import asyncio | |
| import threading | |
| import json | |
| from threading import Lock | |
| from huggingface_hub import HfApi, hf_hub_download | |
| from google import genai | |
| from groq import Groq | |
| from groq import AsyncGroq | |
| # GroqManager is defined first; GeminiManager and HFManager follow further down in this file. | |
class GroqResponse:
    """Minimal response wrapper exposing the completion text as ``.text``.

    ``GroqManager.generate_fast`` returns this so that consumers which read
    ``response.text`` keep working (the original code mentions
    ``AgentInterpretation`` as such a consumer).
    """

    def __init__(self, text):
        self.text = text


class GroqManager:
    """Thin async wrapper around the Groq chat-completions API.

    No client is kept on the instance: every request opens its own
    ``AsyncGroq`` client inside ``async with`` and closes it on exit.
    """

    def __init__(self, api_key):
        """Store the API key and record whether requests can be made.

        Args:
            api_key: Groq API key. A falsy value (``None`` or ``""``) marks
                the manager as unavailable; ``generate_fast`` will then raise.
        """
        self.api_key = api_key
        # Deliberately no client here; one is created per request instead.
        self.client_available = bool(self.api_key)
        # Single model used for both the "fast" and the "smart" path.
        self.model_flash = "llama-3.3-70b-versatile"
        self.last_used_model = "Idle"
        if self.client_available:
            print(f"⚡ Groq Manager Connected: Standardized on {self.model_flash}")

    async def generate_fast(self, prompt):
        """Send ``prompt`` as a single user message and return the reply.

        Args:
            prompt: Text placed in the ``content`` of the one user message.

        Returns:
            A ``GroqResponse`` whose ``text`` is the content of the first
            completion choice.

        Raises:
            RuntimeError: If the manager was constructed without an API key.
        """
        if not self.client_available:
            # RuntimeError is still an Exception, so callers that catch
            # Exception are unaffected; the message text is unchanged.
            raise RuntimeError("Groq API Key not found.")

        self.last_used_model = "Llama 3.3 70B (Groq)"

        # A fresh client per call, created inside the running coroutine and
        # closed automatically when the ``async with`` block exits.
        async with AsyncGroq(api_key=self.api_key) as client:
            chat_completion = await client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                model=self.model_flash,
            )

        return GroqResponse(chat_completion.choices[0].message.content)

    async def generate_smart(self, prompt):
        """Alias of ``generate_fast``; both paths currently use one model."""
        return await self.generate_fast(prompt)

    def get_status_string(self):
        """Return a short status label naming the most recently used model."""
        return f"⚡ {self.last_used_model}"
class GeminiManager:
    """Thin synchronous wrapper around a ``google.genai`` client."""

    def __init__(self, api_key):
        """Create the Gemini client when an API key is supplied.

        Args:
            api_key: Google API key. A falsy value leaves ``self.client`` as
                ``None``; ``generate_fast`` will then raise.
        """
        self.api_key = api_key
        self.client = genai.Client(api_key=self.api_key) if self.api_key else None
        # Single model used for both the "fast" and the "smart" path.
        self.model_flash = "gemini-2.0-flash"
        self.last_used_model = "Idle"
        if self.client:
            print(f"🧠 Gemini Manager Connected: Standardized on {self.model_flash}")

    def generate_fast(self, prompt):
        """Generate content for ``prompt`` with the configured model.

        Args:
            prompt: Passed unchanged as ``contents`` to ``generate_content``.

        Returns:
            Whatever ``client.models.generate_content`` returns.

        Raises:
            RuntimeError: If the manager was constructed without an API key.
        """
        if not self.client:
            # RuntimeError is still an Exception, so callers that catch
            # Exception are unaffected; the message text is unchanged.
            raise RuntimeError("Google API Key not found.")
        self.last_used_model = "Gemini 2.0 Flash"
        return self.client.models.generate_content(
            model=self.model_flash,
            contents=prompt,
        )

    def generate_smart(self, prompt):
        """Alias of ``generate_fast``; both paths use the same model."""
        return self.generate_fast(prompt)

    def get_status_string(self):
        """Return a short status label naming the most recently used model."""
        return f"⚡ {self.last_used_model}"
class HFManager:
    """Synchronise local dataset/profile files with a Hugging Face dataset repo.

    The ``config`` object must provide ``HF_TOKEN``, ``HF_REPO_ID``,
    ``PROFILES_DIR`` and ``DATASET_DIR``.
    """

    # File extensions that pull_datasets downloads; everything else is skipped.
    _SYNCED_EXTENSIONS = (".csv", ".json", ".jsonl")

    def __init__(self, config):
        """Keep ``config`` and build an ``HfApi`` client from its token."""
        self.config = config
        self.api = HfApi(token=self.config.HF_TOKEN)
        # NOTE(review): this lock is not acquired by any method shown here;
        # presumably callers use it to guard local file writes - confirm.
        self.lock = Lock()

    def pull_datasets(self):
        """Download every CSV/JSON/JSONL file of the repo, then seed defaults.

        ``.json`` files go to ``PROFILES_DIR``; ``.csv`` and ``.jsonl`` files
        go to ``DATASET_DIR``. Any error is printed and ends the download
        loop; the default profile is seeded regardless of the outcome.
        """
        print("⬇️ Pulling datasets & profiles...")
        try:
            files = self.api.list_repo_files(
                repo_id=self.config.HF_REPO_ID,
                repo_type="dataset",
            )
            for f in files:
                if not f.endswith(self._SYNCED_EXTENSIONS):
                    continue
                # ".jsonl" does not end with ".json", so it lands in DATASET_DIR.
                if f.endswith(".json"):
                    target_dir = self.config.PROFILES_DIR
                else:
                    target_dir = self.config.DATASET_DIR
                hf_hub_download(
                    repo_id=self.config.HF_REPO_ID,
                    filename=f,
                    repo_type="dataset",
                    local_dir=target_dir,
                    token=self.config.HF_TOKEN,
                )
        except Exception as e:
            print(f"❌ HF Pull Error: {e}")
        self.seed_initial_data()

    def seed_initial_data(self):
        """Create the default "NSL Lab Trainer" profile if it is missing.

        An existing profile file is never overwritten.
        """
        p_path = os.path.join(self.config.PROFILES_DIR, "NSL Lab Trainer.json")
        if os.path.exists(p_path):
            return
        # The profiles directory may not exist yet (for example when the pull
        # failed before anything was downloaded), so create it before writing.
        os.makedirs(self.config.PROFILES_DIR, exist_ok=True)
        with open(p_path, "w", encoding="utf-8") as f:
            json.dump({"lab_name": "NSL", "jargon": {}, "pragmatic_rules": []}, f)

    def push_update(self, filepath, commit_msg="Update"):
        """Upload ``filepath`` to the repo root from a background daemon thread.

        Returns immediately; success or failure is only reported via ``print``.

        Args:
            filepath: Local file to upload; stored under its base name.
            commit_msg: Commit message for the upload.
        """
        def _upload_task():
            # NOTE(review): upload_file is called synchronously here; the
            # thread-local event loop is presumably a workaround for code that
            # expects a current loop in worker threads - confirm before removing.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                filename = os.path.basename(filepath)
                self.api.upload_file(
                    path_or_fileobj=filepath,
                    path_in_repo=filename,
                    repo_id=self.config.HF_REPO_ID,
                    repo_type="dataset",
                    commit_message=commit_msg,
                )
                print(f"✅ Background Sync Complete: {filename}")
            except Exception as e:
                print(f"❌ HF Push Error: {e}")
            finally:
                loop.close()

        threading.Thread(target=_upload_task, daemon=True).start()

    def upload_audio_sample(self, audio_path, dialect):
        """Upload an audio file to ``audio/<dialect>/<basename>`` in the repo.

        Args:
            audio_path: Local path of the audio file.
            dialect: Dialect label; surrounding whitespace is stripped before
                it is used as the folder name.

        Returns:
            The path inside the repo on success, ``None`` if the upload failed.
        """
        clean_dialect = dialect.strip()
        filename = os.path.basename(audio_path)
        hf_path = f"audio/{clean_dialect}/{filename}"
        try:
            self.api.upload_file(
                path_or_fileobj=audio_path,
                path_in_repo=hf_path,
                repo_id=self.config.HF_REPO_ID,
                repo_type="dataset",
                commit_message=f"Add audio sample for {clean_dialect}",
            )
            return hf_path
        except Exception as e:
            # Still best-effort (None signals failure), but log the reason
            # instead of discarding it, matching the other methods.
            print(f"❌ HF Audio Upload Error: {e}")
            return None