Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import logging | |
| import aiohttp | |
| from typing import Dict, Optional | |
| from datetime import datetime | |
| from services.storage import storage | |
# Configure the root logger at INFO level as a side effect of importing this
# module (logging.basicConfig does nothing if the root logger already has
# handlers).
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; used by the classes below.
logger = logging.getLogger(__name__)
class RepoConnection:
    """Keep the events and metrics streams of one Hugging Face Space open.

    The events stream (``/api/spaces/<namespace>/<repo>/events``) is read for
    as long as the connection is running.  Every JSON object received on a
    ``data:`` line is inspected for ``compute.status.stage``; the stage is
    forwarded to ``storage.update_status`` and decides whether the metrics
    stream (``/api/spaces/<namespace>/<repo>/metrics``) should be running:
    it is started when the stage becomes ``"RUNNING"`` and cancelled when the
    stage becomes anything else.  Every JSON value received on a metrics
    ``data:`` line is forwarded to ``storage.add_metric``.

    Both streams reconnect with capped exponential backoff (1 s doubling up
    to 30 s), each with its own retry counter.
    """

    def __init__(self, namespace: str, repo: str, repo_id: str):
        """Store the Space coordinates; no network activity happens here.

        Args:
            namespace: Owner part of the Space path used to build the URLs.
            repo: Repository part of the Space path used to build the URLs.
            repo_id: Identifier passed to every ``storage`` call.
        """
        self.namespace = namespace
        self.repo = repo
        self.repo_id = repo_id
        # Most recent session objects opened by the two listener loops.
        self.events_session: Optional[aiohttp.ClientSession] = None
        self.metrics_session: Optional[aiohttp.ClientSession] = None
        # Background tasks running _listen_events / _listen_metrics.
        self.events_task: Optional[asyncio.Task] = None
        self.metrics_task: Optional[asyncio.Task] = None
        self._running = False
        self._metrics_active = False
        # Separate counters so that failures of one stream do not inflate the
        # reconnect delay of the other.
        self._retry_count = 0
        self._metrics_retry_count = 0
        # Upper bound for the retry counters (not a limit on reconnects).
        self._max_retries = 10

    async def start(self):
        """Mark the connection as running and spawn the events listener."""
        self._running = True
        self.events_task = asyncio.create_task(self._listen_events())
        logger.info(f"Started connections for {self.namespace}/{self.repo}")

    async def stop(self):
        """Stop both listeners, wait for them to finish, close the sessions."""
        self._running = False
        self._metrics_active = False
        await self._cancel_task(self.events_task)
        await self._cancel_task(self.metrics_task)
        if self.events_session:
            await self.events_session.close()
        if self.metrics_session:
            await self.metrics_session.close()
        logger.info(f"Stopped connections for {self.namespace}/{self.repo}")

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel *task* (if any) and wait until it has finished.

        ``asyncio.wait`` is used instead of ``await task`` so that the
        cancellation of *task* is not re-raised here, while a cancellation of
        the *calling* task still propagates instead of being swallowed.
        """
        if task is None:
            return
        task.cancel()
        await asyncio.wait({task})

    @staticmethod
    def _stream_timeout():
        """Build the aiohttp timeout used for both long-lived streams.

        ``total`` is disabled because aiohttp applies it to the whole request
        including reading the body, which would cut every stream after a
        fixed time.  Connection setup and idle reads are still bounded.
        """
        return aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)

    @staticmethod
    def _sse_data_field(raw_line: bytes) -> Optional[str]:
        """Return the payload of a ``data:`` line, or ``None`` for other lines.

        Undecodable bytes are replaced rather than raised so that one bad
        line cannot tear down the stream; the caller's JSON parsing decides
        whether the payload is usable.
        """
        text = raw_line.decode("utf-8", errors="replace").strip()
        field, _, value = text.partition(":")
        if field != "data":
            return None
        return value.strip()

    @staticmethod
    def _extract_stage(event: dict) -> str:
        """Return ``event["compute"]["status"]["stage"]`` or ``"Unknown"``.

        ``"Unknown"`` is returned when a key is missing or when an
        intermediate value is not a dict (for example ``null``).
        """
        node = event
        for key in ("compute", "status"):
            if not isinstance(node, dict):
                return "Unknown"
            node = node.get(key, {})
        if not isinstance(node, dict):
            return "Unknown"
        return node.get("stage", "Unknown")

    @staticmethod
    def _backoff_delay(retry_count: int) -> float:
        """Return the delay in seconds for *retry_count*: 1 s doubling, max 30 s."""
        return min(1.0 * (2 ** retry_count), 30.0)

    async def _apply_stage(self, stage: str) -> None:
        """Record *stage* and start or stop the metrics listener to match it."""
        storage.update_status(self.repo_id, "CONNECTED", stage)
        if stage == "RUNNING" and not self._metrics_active:
            self._metrics_active = True
            self.metrics_task = asyncio.create_task(self._listen_metrics())
        elif stage != "RUNNING" and self._metrics_active:
            self._metrics_active = False
            await self._cancel_task(self.metrics_task)

    async def _listen_events(self):
        """Read the events stream until stopped, reconnecting with backoff."""
        url = f"https://huggingface.co/api/spaces/{self.namespace}/{self.repo}/events"
        while self._running:
            try:
                async with aiohttp.ClientSession() as session:
                    self.events_session = session
                    async with session.get(url, timeout=self._stream_timeout()) as resp:
                        if resp.status != 200:
                            raise Exception(f"HTTP {resp.status}")
                        storage.update_status(self.repo_id, "CONNECTED", "Unknown")
                        self._retry_count = 0
                        async for line in resp.content:
                            if not self._running:
                                break
                            payload = self._sse_data_field(line)
                            if payload is None:
                                continue
                            try:
                                data = json.loads(payload)
                            except json.JSONDecodeError:
                                continue
                            if not isinstance(data, dict):
                                continue  # skip unexpected payloads
                            await self._apply_stage(self._extract_stage(data))
            except asyncio.CancelledError:
                # Let cancellation end the task; stop() waits for that.
                raise
            except Exception as e:
                logger.error(f"Events stream error for {self.namespace}/{self.repo}: {e}")
                storage.update_status(self.repo_id, "ERROR", "Unknown")
            # Pause before every reconnect, including after a stream that
            # ended without an error, so a server that closes immediately
            # cannot cause a tight reconnect loop.
            if self._running:
                delay = await self._exponential_backoff()
                await asyncio.sleep(delay)

    async def _listen_metrics(self):
        """Read the metrics stream while active, reconnecting with backoff."""
        url = f"https://huggingface.co/api/spaces/{self.namespace}/{self.repo}/metrics"
        while self._running and self._metrics_active:
            try:
                async with aiohttp.ClientSession() as session:
                    self.metrics_session = session
                    async with session.get(url, timeout=self._stream_timeout()) as resp:
                        if resp.status != 200:
                            raise Exception(f"HTTP {resp.status}")
                        self._metrics_retry_count = 0
                        async for line in resp.content:
                            if not self._running or not self._metrics_active:
                                break
                            payload = self._sse_data_field(line)
                            if payload is None:
                                continue
                            try:
                                data = json.loads(payload)
                            except json.JSONDecodeError:
                                # Same policy as the events stream: a single
                                # malformed line is skipped, not fatal.
                                continue
                            storage.add_metric(self.repo_id, data)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error(f"Metrics stream error for {self.namespace}/{self.repo}: {e}")
            if self._running and self._metrics_active:
                delay = self._backoff_delay(self._metrics_retry_count)
                self._metrics_retry_count = min(
                    self._metrics_retry_count + 1, self._max_retries
                )
                await asyncio.sleep(delay)

    async def _exponential_backoff(self):
        """Return the next events-stream delay and advance its retry counter.

        Returns:
            The delay in seconds (1, 2, 4, ... capped at 30).
        """
        delay = self._backoff_delay(self._retry_count)
        self._retry_count = min(self._retry_count + 1, self._max_retries)
        return delay
class ConnectionRegistry:
    """Map repo ids to their ``RepoConnection`` and start/stop them.

    ``register`` and ``unregister`` are synchronous but schedule the
    connection's ``start()`` / ``stop()`` coroutine with
    ``asyncio.create_task``, so they must be called while an event loop is
    running.
    """

    def __init__(self):
        """Create an empty registry."""
        self.connections: Dict[str, RepoConnection] = {}
        # Strong references to in-flight start()/stop() tasks.  The event
        # loop only keeps weak references to tasks, so a task nobody else
        # references may be garbage-collected before it finishes.
        self._background_tasks = set()

    def _spawn(self, coro) -> None:
        """Schedule *coro* as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def register(self, repo):
        """Create and start a connection for *repo* unless one already exists.

        Args:
            repo: Object exposing ``id``, ``namespace`` and ``repo``
                attributes.
        """
        if repo.id in self.connections:
            return
        conn = RepoConnection(repo.namespace, repo.repo, repo.id)
        # Schedule first: if no event loop is running, create_task raises and
        # no never-started connection is left behind in the registry.
        self._spawn(conn.start())
        self.connections[repo.id] = conn

    def unregister(self, repo_id: str):
        """Schedule the connection for *repo_id* to stop and forget it.

        Unknown ids are ignored.
        """
        conn = self.connections.get(repo_id)
        if conn is None:
            return
        self._spawn(conn.stop())
        del self.connections[repo_id]

    def get(self, repo_id: str) -> Optional[RepoConnection]:
        """Return the connection registered for *repo_id*, or ``None``."""
        return self.connections.get(repo_id)
# Module-level ConnectionRegistry instance, created when this module is imported.
registry = ConnectionRegistry()