import asyncio
import json
import logging
from datetime import datetime
from typing import Dict, Optional, Set

import aiohttp

from services.storage import storage

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RepoConnection:
    """Follow the Hugging Face event and metrics streams of one Space.

    The events stream (``/events``) runs until :meth:`stop` is called and
    records the Space's compute stage via ``storage.update_status``. While
    that stage is ``"RUNNING"`` a second stream (``/metrics``) is followed
    and every ``data:`` payload is passed to ``storage.add_metric``.
    """

    API_BASE = "https://huggingface.co/api/spaces"
    # Seconds allowed for establishing the HTTP connection.
    CONNECT_TIMEOUT = 60.0
    # Seconds of silence on an open stream before it is treated as dead and
    # reconnected. Unlike aiohttp's ``total`` timeout this does not cap the
    # lifetime of a healthy stream.
    READ_TIMEOUT = 60.0
    # Upper bound, in seconds, for the reconnect delay.
    MAX_BACKOFF = 30.0

    def __init__(self, namespace: str, repo: str, repo_id: str):
        self.namespace = namespace
        self.repo = repo
        self.repo_id = repo_id
        self.events_session: Optional[aiohttp.ClientSession] = None
        self.metrics_session: Optional[aiohttp.ClientSession] = None
        self.events_task: Optional[asyncio.Task] = None
        self.metrics_task: Optional[asyncio.Task] = None
        self._running = False
        self._metrics_active = False
        # Consecutive-failure counters, one per stream, so failures on one
        # stream do not lengthen the other stream's reconnect delay.
        self._retry_count = 0
        self._metrics_retry_count = 0
        self._max_retries = 10

    async def start(self):
        """Start following the events stream (no-op if already started)."""
        if self.events_task is not None and not self.events_task.done():
            return
        self._running = True
        self.events_task = asyncio.create_task(self._listen_events())
        logger.info(f"Started connections for {self.namespace}/{self.repo}")

    async def stop(self):
        """Stop both streams, wait for their tasks to finish, close sessions."""
        self._running = False
        self._metrics_active = False
        await self._cancel_and_wait(self.events_task)
        await self._cancel_and_wait(self.metrics_task)
        # The sessions are normally closed by their ``async with`` blocks when
        # the tasks unwind; this only covers anything left open.
        for session in (self.events_session, self.metrics_session):
            if session is not None and not session.closed:
                await session.close()
        logger.info(f"Stopped connections for {self.namespace}/{self.repo}")

    def _stream_timeout(self) -> aiohttp.ClientTimeout:
        """Timeout for a long-lived SSE request.

        ``total=None`` so a healthy stream is not cut off after a fixed
        duration; ``connect`` bounds connection setup and ``sock_read`` bounds
        the gap between received chunks.
        """
        return aiohttp.ClientTimeout(
            total=None,
            connect=self.CONNECT_TIMEOUT,
            sock_read=self.READ_TIMEOUT,
        )

    @staticmethod
    def _sse_data(raw: bytes) -> Optional[str]:
        """Return the value of an SSE ``data:`` line, or None for other lines."""
        line = raw.decode("utf-8", errors="replace").strip()
        field, sep, value = line.partition(":")
        if field != "data" or not sep:
            return None
        return value.strip()

    @staticmethod
    def _extract_stage(data: dict) -> str:
        """Return ``data["compute"]["status"]["stage"]``, or "Unknown".

        Missing keys and non-object intermediate values (e.g. ``null``)
        both yield "Unknown" instead of raising.
        """
        compute = data.get("compute")
        status = compute.get("status") if isinstance(compute, dict) else None
        if not isinstance(status, dict):
            return "Unknown"
        return status.get("stage", "Unknown")

    @staticmethod
    async def _cancel_and_wait(task: Optional[asyncio.Task]) -> None:
        """Cancel *task* (if still pending) and wait until it has finished.

        ``asyncio.wait`` does not re-raise the awaited task's own
        ``CancelledError``, so a ``CancelledError`` escaping from here always
        means the *caller* was cancelled, and it propagates as it should.
        """
        if task is None or task.done():
            return
        task.cancel()
        await asyncio.wait({task})

    async def _stop_metrics(self) -> None:
        """Deactivate the metrics stream and wait for its task to end."""
        self._metrics_active = False
        await self._cancel_and_wait(self.metrics_task)

    async def _handle_event_line(self, line: bytes) -> None:
        """Apply one raw line from the events stream."""
        payload = self._sse_data(line)
        if payload is None:
            return
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            return
        if not isinstance(data, dict):
            return  # skip unexpected payloads
        stage = self._extract_stage(data)
        storage.update_status(self.repo_id, "CONNECTED", stage)
        if stage == "RUNNING" and not self._metrics_active:
            self._metrics_active = True
            self.metrics_task = asyncio.create_task(self._listen_metrics())
        elif stage != "RUNNING" and self._metrics_active:
            await self._stop_metrics()

    async def _listen_events(self):
        """Follow the Space's events stream until :meth:`stop` is called.

        On connect the repo is marked CONNECTED/"Unknown"; each JSON-object
        payload then updates the stored stage. Any failure marks the repo
        ERROR/"Unknown" and reconnects after an exponential backoff.
        """
        url = f"{self.API_BASE}/{self.namespace}/{self.repo}/events"
        while self._running:
            try:
                async with aiohttp.ClientSession() as session:
                    self.events_session = session
                    async with session.get(url, timeout=self._stream_timeout()) as resp:
                        if resp.status != 200:
                            raise Exception(f"HTTP {resp.status}")
                        storage.update_status(self.repo_id, "CONNECTED", "Unknown")
                        self._retry_count = 0
                        async for line in resp.content:
                            if not self._running:
                                break
                            await self._handle_event_line(line)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Events stream error for {self.namespace}/{self.repo}: {e}")
                storage.update_status(self.repo_id, "ERROR", "Unknown")
            if self._running:
                # Also reached when the server ends the stream cleanly; the
                # delay prevents a tight reconnect loop in that case.
                delay = await self._exponential_backoff()
                await asyncio.sleep(delay)

    async def _listen_metrics(self):
        """Forward metrics-stream payloads to storage while metrics are active.

        Malformed JSON payloads are skipped. Connection failures are logged
        and retried with this stream's own backoff counter.
        """
        url = f"{self.API_BASE}/{self.namespace}/{self.repo}/metrics"
        while self._running and self._metrics_active:
            try:
                async with aiohttp.ClientSession() as session:
                    self.metrics_session = session
                    async with session.get(url, timeout=self._stream_timeout()) as resp:
                        if resp.status != 200:
                            raise Exception(f"HTTP {resp.status}")
                        self._metrics_retry_count = 0
                        async for line in resp.content:
                            if not self._running or not self._metrics_active:
                                break
                            payload = self._sse_data(line)
                            if payload is None:
                                continue
                            try:
                                data = json.loads(payload)
                            except json.JSONDecodeError:
                                continue  # one bad line must not drop the stream
                            storage.add_metric(self.repo_id, data)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Metrics stream error for {self.namespace}/{self.repo}: {e}")
            if self._running and self._metrics_active:
                await asyncio.sleep(self._metrics_backoff())

    def _backoff_delay(self, attempts: int) -> float:
        """Return ``2 ** attempts`` seconds, capped at ``MAX_BACKOFF``."""
        return min(2.0 ** attempts, self.MAX_BACKOFF)

    async def _exponential_backoff(self):
        """Return the next events-stream reconnect delay and bump its counter."""
        delay = self._backoff_delay(self._retry_count)
        self._retry_count = min(self._retry_count + 1, self._max_retries)
        return delay

    def _metrics_backoff(self) -> float:
        """Return the next metrics-stream reconnect delay and bump its counter."""
        delay = self._backoff_delay(self._metrics_retry_count)
        self._metrics_retry_count = min(self._metrics_retry_count + 1, self._max_retries)
        return delay


class ConnectionRegistry:
    """Hold at most one :class:`RepoConnection` per repo id."""

    def __init__(self):
        self.connections: Dict[str, RepoConnection] = {}
        # Strong references to fire-and-forget start/stop tasks. The event
        # loop keeps only weak references to tasks, so an unreferenced task
        # can be garbage-collected before it completes.
        self._background_tasks: Set[asyncio.Task] = set()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* and keep a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def register(self, repo):
        """Create and start a connection for *repo* unless one already exists.

        *repo* must expose ``id``, ``namespace`` and ``repo`` attributes.
        Requires a running event loop.
        """
        if repo.id in self.connections:
            return
        conn = RepoConnection(repo.namespace, repo.repo, repo.id)
        self.connections[repo.id] = conn
        self._spawn(conn.start())

    def unregister(self, repo_id: str):
        """Remove the connection for *repo_id* (if any) and stop it in the background."""
        conn = self.connections.pop(repo_id, None)
        if conn is not None:
            self._spawn(conn.stop())

    def get(self, repo_id: str) -> Optional[RepoConnection]:
        """Return the connection for *repo_id*, or None if not registered."""
        return self.connections.get(repo_id)


registry = ConnectionRegistry()