# Spaces: Sleeping
| import os | |
| from typing import Tuple, Dict, List | |
| from openai import AsyncOpenAI | |
| import httpx | |
| from config import settings | |
| from .ollama_client import OllamaClient | |
| from .hf_client import HFClient | |
class ModelManager:
    """Select and manage the LLM backend client used by each agent.

    Backends handled here: an Ollama client, an optional Hugging Face client,
    OpenAI, OpenRouter, an environment-level override
    (``API_BASE_URL`` / ``API_KEY``) and a user-registered custom model.
    """

    # Placeholder strings that must never be treated as a real HF token.
    _HF_TOKEN_PLACEHOLDERS = frozenset(
        {"", "your_huggingface_token_here", "ollama", "hf_YourTokenHere"}
    )
    # CUSTOM_MODEL_AGENT values that apply the custom model to every agent.
    _CUSTOM_AGENT_WILDCARDS = frozenset({"both", "all"})
    # Port probed when OLLAMA_BASE_URL does not name one explicitly.
    _DEFAULT_OLLAMA_PORT = 11434

    def __init__(self):
        """Create the Ollama client and, when a usable token exists, the HF client."""
        self.ollama = OllamaClient(settings.OLLAMA_BASE_URL, settings.OLLAMA_API_KEY)
        self.hf_client = None
        hf_token = self._current_hf_token()
        if self._is_usable_hf_token(hf_token):
            self.hf_client = HFClient(settings.HF_INFERENCE_URL, hf_token)

    @staticmethod
    def _current_hf_token() -> str:
        """Return the HF token from the environment, else from settings, else ""."""
        return os.environ.get("HF_TOKEN", "") or settings.HF_TOKEN or ""

    @classmethod
    def _is_usable_hf_token(cls, token) -> bool:
        """Return True when *token* is non-empty and not a known placeholder."""
        return bool(token) and token not in cls._HF_TOKEN_PLACEHOLDERS

    @classmethod
    def _custom_model_targets(cls, agent_id: str) -> bool:
        """Return True when the configured custom model applies to *agent_id*."""
        target = (settings.CUSTOM_MODEL_AGENT or "").lower()
        return target == agent_id.lower() or target in cls._CUSTOM_AGENT_WILDCARDS

    def get_client(self, agent_id: str) -> "Tuple[AsyncOpenAI, str]":
        """Return ``(client, model_name)`` for the given agent.

        Resolution order (first match wins):

        1. ``API_BASE_URL`` + ``API_KEY`` environment override, unless the
           agent's provider is ``"openai"``.
        2. The custom model, when enabled and targeted at this agent.
        3. OpenAI, when the provider is ``"openai"`` and a key is configured.
        4. OpenRouter, when the provider is ``"openrouter"`` and a key is
           configured.
        5. Hugging Face, when the provider is ``"hf"`` or Ollama is
           unreachable, and a usable HF token is available.
        6. Ollama.

        An unknown ``agent_id`` falls back to the first configured agent, or
        to an Ollama ``llama3`` default when no agents are configured.
        """
        agent_config = next((a for a in settings.AGENTS if a["id"] == agent_id), None)
        if not agent_config:
            # Fallback for unrecognized agent
            agent_config = (
                settings.AGENTS[0]
                if settings.AGENTS
                else {"provider": "ollama", "model": "llama3"}
            )
        provider = agent_config.get("provider", "ollama")
        model_name = os.environ.get("MODEL_NAME", "") or agent_config.get("model", "")

        api_base = os.environ.get("API_BASE_URL", "")
        api_key = os.environ.get("API_KEY", "")
        if api_base and api_key and provider != "openai":
            return AsyncOpenAI(base_url=api_base, api_key=api_key), model_name

        if settings.CUSTOM_MODEL_ENABLED and self._custom_model_targets(agent_id):
            client = AsyncOpenAI(
                base_url=settings.CUSTOM_MODEL_BASE_URL,
                api_key=settings.CUSTOM_MODEL_API_KEY or "none",
            )
            return client, settings.CUSTOM_MODEL_NAME

        openai_key = os.environ.get("OPENAI_API_KEY", "") or getattr(
            settings, "OPENAI_API_KEY", ""
        )
        if provider == "openai" and openai_key:
            client = AsyncOpenAI(
                api_key=openai_key,
                base_url=getattr(settings, "OPENAI_BASE_URL", "https://api.openai.com/v1"),
            )
            return client, model_name

        # Explicitly configured providers are resolved before the
        # "Ollama is unreachable" fallback. Otherwise an OpenRouter agent would
        # be routed to Hugging Face whenever Ollama is down.
        if provider == "openrouter" and getattr(settings, "OPENROUTER_API_KEY", ""):
            client = AsyncOpenAI(
                api_key=settings.OPENROUTER_API_KEY,
                base_url=settings.OPENROUTER_BASE_URL,
            )
            return client, model_name

        if provider == "hf" or not self._is_ollama_available():
            if self.hf_client:
                return self.hf_client.get_client(), model_name
            # The token may have been set after __init__ ran; re-check it.
            hf_token = self._current_hf_token()
            if self._is_usable_hf_token(hf_token):
                temp_client = HFClient(settings.HF_INFERENCE_URL, hf_token)
                return temp_client.get_client(), model_name

        return self.ollama.get_client(), model_name

    @staticmethod
    def _parse_host_port(base_url: str, default_port: int = _DEFAULT_OLLAMA_PORT):
        """Split *base_url* into ``(host, port)``.

        Accepts URLs with or without a scheme, port or path. When no port is
        present, *default_port* is returned.

        Raises:
            ValueError: if no host can be extracted or the port is invalid.
        """
        from urllib.parse import urlsplit

        # A bare "host:port" has no netloc unless prefixed with "//".
        parts = urlsplit(base_url if "://" in base_url else f"//{base_url}")
        if not parts.hostname:
            raise ValueError(f"No host in URL: {base_url!r}")
        port = parts.port if parts.port is not None else default_port
        return parts.hostname, port

    def _is_ollama_available(self) -> bool:
        """Return True when a TCP connection to the Ollama host succeeds.

        Performs a blocking connect with a 1 second timeout. Any parsing or
        network error is reported as "not available".
        """
        import socket

        try:
            host, port = self._parse_host_port(settings.OLLAMA_BASE_URL)
            # The context manager closes the socket even if connect_ex raises.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.settimeout(1)
                return sock.connect_ex((host, port)) == 0
        except (OSError, ValueError, TypeError, AttributeError):
            return False

    async def add_custom_model(self, agent_id: str, base_url: str, api_key: str, model_name: str) -> dict:
        """Validate an OpenAI-compatible endpoint and activate it as the custom model.

        Sends a single short chat completion (30 s timeout). On success the
        configuration is written to the env file and applied to ``settings``
        in memory.

        Returns:
            A dict with ``success`` (bool) and ``message`` (str). Any
            exception is reported as a failed result rather than raised.
        """
        try:
            client = AsyncOpenAI(base_url=base_url, api_key=api_key or "none")
            response = await client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": "Say 'hello' in exactly one word."}],
                max_tokens=10,
                timeout=30.0,
            )
            if not (response and response.choices):
                return {"success": False, "message": "Model did not return a valid completion."}

            self._update_env_file(
                {
                    "CUSTOM_MODEL_ENABLED": "true",
                    "CUSTOM_MODEL_BASE_URL": base_url,
                    "CUSTOM_MODEL_API_KEY": api_key,
                    "CUSTOM_MODEL_NAME": model_name,
                    "CUSTOM_MODEL_AGENT": agent_id,
                }
            )
            settings.CUSTOM_MODEL_ENABLED = True
            settings.CUSTOM_MODEL_BASE_URL = base_url
            settings.CUSTOM_MODEL_API_KEY = api_key
            settings.CUSTOM_MODEL_NAME = model_name
            settings.CUSTOM_MODEL_AGENT = agent_id
            return {"success": True, "message": "Custom model verified and activated."}
        except Exception as e:
            return {"success": False, "message": f"Validation failed: {e}"}

    async def remove_custom_model(self, agent_id: str):
        """Disable the custom model if it is assigned to *agent_id*.

        Uses the same matching as ``get_client``, so a model registered for
        ``"both"`` or ``"all"`` can be disabled from any agent.
        """
        if self._custom_model_targets(agent_id):
            self._update_env_file({"CUSTOM_MODEL_ENABLED": "false"})
            settings.CUSTOM_MODEL_ENABLED = False

    async def list_available_models(self) -> List[str]:
        """Return model ids from the Hugging Face hub, else from Ollama.

        With a usable HF token, the first 50 ids returned by the hub API are
        used. On any HTTP or parsing failure, or without a token, the result
        of ``self.ollama.list_models()`` is returned instead.
        """
        hf_token = self._current_hf_token()
        if self._is_usable_hf_token(hf_token):
            try:
                async with httpx.AsyncClient() as client:
                    resp = await client.get(
                        "https://huggingface.co/api/models",
                        headers={"Authorization": f"Bearer {hf_token}"},
                        timeout=30.0,
                    )
                if resp.status_code == 200:
                    models = resp.json()
                    return [m["id"] for m in models[:50]]
            except (httpx.HTTPError, ValueError, KeyError, TypeError):
                # Best-effort: fall through to Ollama. Deliberately not a bare
                # except, so task cancellation still propagates.
                pass
        return await self.ollama.list_models()

    def pull_model(self, model_name: str):
        """Delegate to the Ollama client's ``pull_model``."""
        return self.ollama.pull_model(model_name)

    @staticmethod
    def _merge_env_lines(lines, overrides: dict) -> list:
        """Return *lines* with ``KEY=value`` entries replaced or appended.

        Lines starting with ``KEY=`` for a key in *overrides* are rewritten.
        Keys with no existing line are appended at the end, so new settings
        are not silently lost.
        """
        merged = []
        written = set()
        for line in lines:
            key = next((k for k in overrides if line.startswith(f"{k}=")), None)
            if key is None:
                merged.append(line)
            else:
                merged.append(f"{key}={overrides[key]}\n")
                written.add(key)

        missing = [k for k in overrides if k not in written]
        if missing and merged and not merged[-1].endswith("\n"):
            # Keep the first appended key off the file's unterminated last line.
            merged[-1] += "\n"
        merged.extend(f"{k}={overrides[k]}\n" for k in missing)
        return merged

    def _update_env_file(self, overrides: dict):
        """Persist *overrides* into ``default.env`` one directory above this file's.

        Does nothing when the file does not exist.
        """
        # NOTE(review): values are written verbatim; a value containing a
        # newline would add extra lines to the file.
        env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "default.env")
        if not os.path.exists(env_path):
            return
        with open(env_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        with open(env_path, "w", encoding="utf-8") as f:
            f.writelines(self._merge_env_lines(lines, overrides))
# Module-level instance, constructed when this module is imported.
model_manager = ModelManager()