# SpaceProbe1/services/registry.py
# a9's picture
# Upload 10 files
# 867e59a verified
import asyncio
import json
import logging
import aiohttp
from typing import Dict, Optional
from datetime import datetime
from services.storage import storage
# Configure the root logger at INFO level as a side effect of importing this
# module (logging.basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the classes below.
logger = logging.getLogger(__name__)
class RepoConnection:
    """Keep the event and metrics streams of one Hugging Face Space open.

    The events stream is read for as long as the connection is running. Every
    decoded event reports the Space's stage through ``storage.update_status``;
    while the stage is ``"RUNNING"`` a second task reads the metrics stream and
    forwards each payload through ``storage.add_metric``. Both streams
    reconnect with exponential backoff after an error.
    """

    # Seconds allowed for establishing the HTTP connection of a stream.
    _CONNECT_TIMEOUT = 30.0
    # Seconds a stream may stay silent before the read is abandoned.
    _IDLE_TIMEOUT = 60.0

    def __init__(self, namespace: str, repo: str, repo_id: str):
        """Store the Space coordinates; no network activity happens here.

        Args:
            namespace: Owner part of the Space path used in the stream URLs.
            repo: Name part of the Space path used in the stream URLs.
            repo_id: Identifier passed to ``storage`` for status and metrics.
        """
        self.namespace = namespace
        self.repo = repo
        self.repo_id = repo_id
        self.events_session: Optional[aiohttp.ClientSession] = None
        self.metrics_session: Optional[aiohttp.ClientSession] = None
        self.events_task: Optional[asyncio.Task] = None
        self.metrics_task: Optional[asyncio.Task] = None
        self._running = False
        self._metrics_active = False
        # Backoff state shared by both streams; reset when the events stream connects.
        self._retry_count = 0
        self._max_retries = 10

    async def start(self):
        """Mark the connection as running and spawn the events listener task."""
        self._running = True
        self.events_task = asyncio.create_task(self._listen_events())
        logger.info(f"Started connections for {self.namespace}/{self.repo}")

    async def stop(self):
        """Stop both listeners, wait for them to finish and close the sessions."""
        self._running = False
        self._metrics_active = False
        await self._cancel_and_wait(self.events_task)
        await self._cancel_and_wait(self.metrics_task)
        if self.events_session:
            await self.events_session.close()
        if self.metrics_session:
            await self.metrics_session.close()
        logger.info(f"Stopped connections for {self.namespace}/{self.repo}")

    @staticmethod
    async def _cancel_and_wait(task: Optional[asyncio.Task]) -> None:
        """Cancel ``task`` (if it is still pending) and wait until it has finished.

        ``asyncio.wait`` is used instead of ``await task`` so that the only
        ``CancelledError`` that can surface here is a cancellation of the
        *caller*, which is then propagated instead of being swallowed.
        """
        if task is None or task.done():
            return
        task.cancel()
        await asyncio.wait([task])

    def _stream_timeout(self):
        """Return the aiohttp timeout used for the long-lived streams.

        ``total`` is disabled because it also covers reading the response
        body, which would cut every stream off after a fixed duration; only
        connection setup and silence between reads are bounded.
        """
        return aiohttp.ClientTimeout(
            total=None,
            connect=self._CONNECT_TIMEOUT,
            sock_read=self._IDLE_TIMEOUT,
        )

    @staticmethod
    def _sse_data_field(line: bytes) -> Optional[str]:
        """Return the payload text of a ``data:`` line, or ``None`` for any other line.

        Lines that are not valid UTF-8 are treated like non-data lines.
        """
        try:
            decoded = line.decode("utf-8").strip()
        except UnicodeDecodeError:
            return None
        if not decoded.startswith("data:"):
            return None
        return decoded[len("data:"):].strip()

    @staticmethod
    def _extract_stage(payload: dict):
        """Return ``payload["compute"]["status"]["stage"]``, or ``"Unknown"``.

        ``"Unknown"`` is returned when a level is missing or is not a dict
        (for example an explicit JSON ``null``), so an unexpected payload
        shape cannot raise and drop the events stream.
        """
        node = payload
        for key in ("compute", "status"):
            node = node.get(key)
            if not isinstance(node, dict):
                return "Unknown"
        return node.get("stage", "Unknown")

    async def _handle_event_line(self, line: bytes) -> None:
        """Process one raw line of the events stream.

        Non-data lines, invalid JSON and non-object payloads are skipped.
        Otherwise the stage is stored and the metrics listener is started or
        stopped to follow transitions into and out of ``"RUNNING"``.
        """
        data_str = self._sse_data_field(line)
        if data_str is None:
            return
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            return
        if not isinstance(data, dict):
            return  # skip unexpected payloads

        stage = self._extract_stage(data)
        storage.update_status(self.repo_id, "CONNECTED", stage)
        if stage == "RUNNING" and not self._metrics_active:
            self._metrics_active = True
            self.metrics_task = asyncio.create_task(self._listen_metrics())
        elif stage != "RUNNING" and self._metrics_active:
            self._metrics_active = False
            await self._cancel_and_wait(self.metrics_task)

    def _handle_metric_line(self, line: bytes) -> None:
        """Process one raw line of the metrics stream.

        A ``data:`` line carrying valid JSON is handed to
        ``storage.add_metric``; anything else is skipped so a single bad line
        does not force a reconnect.
        """
        data_str = self._sse_data_field(line)
        if data_str is None:
            return
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            logger.debug(
                "Skipping malformed metrics line for %s/%s", self.namespace, self.repo
            )
            return
        storage.add_metric(self.repo_id, data)

    async def _listen_events(self):
        """Read the events stream until stopped, reconnecting after errors."""
        url = f"https://huggingface.co/api/spaces/{self.namespace}/{self.repo}/events"
        while self._running:
            try:
                async with aiohttp.ClientSession() as session:
                    self.events_session = session
                    async with session.get(url, timeout=self._stream_timeout()) as resp:
                        if resp.status != 200:
                            raise Exception(f"HTTP {resp.status}")
                        storage.update_status(self.repo_id, "CONNECTED", "Unknown")
                        self._retry_count = 0
                        async for line in resp.content:
                            if not self._running:
                                break
                            await self._handle_event_line(line)
            except asyncio.CancelledError:
                # Cancellation is the normal shutdown path used by stop().
                break
            except Exception as e:
                logger.error(f"Events stream error for {self.namespace}/{self.repo}: {e}")
                storage.update_status(self.repo_id, "ERROR", "Unknown")
                if self._running:
                    delay = await self._exponential_backoff()
                    await asyncio.sleep(delay)

    async def _listen_metrics(self):
        """Read the metrics stream while the Space is running, reconnecting after errors."""
        url = f"https://huggingface.co/api/spaces/{self.namespace}/{self.repo}/metrics"
        while self._running and self._metrics_active:
            try:
                async with aiohttp.ClientSession() as session:
                    self.metrics_session = session
                    async with session.get(url, timeout=self._stream_timeout()) as resp:
                        if resp.status != 200:
                            raise Exception(f"HTTP {resp.status}")
                        async for line in resp.content:
                            if not (self._running and self._metrics_active):
                                break
                            self._handle_metric_line(line)
            except asyncio.CancelledError:
                # Cancellation is how the events listener and stop() end this task.
                break
            except Exception as e:
                logger.error(f"Metrics stream error for {self.namespace}/{self.repo}: {e}")
                if self._running and self._metrics_active:
                    delay = await self._exponential_backoff()
                    await asyncio.sleep(delay)

    async def _exponential_backoff(self):
        """Return the next reconnect delay in seconds and advance the retry counter.

        The delay doubles with every call starting at 1 second and is capped
        at 30 seconds; the counter itself is capped at ``_max_retries``.
        """
        delay = min(1.0 * (2 ** self._retry_count), 30.0)
        self._retry_count = min(self._retry_count + 1, self._max_retries)
        return delay
class ConnectionRegistry:
    """Track one ``RepoConnection`` per repository id.

    ``register`` and ``unregister`` are synchronous but schedule the
    connection's ``start()`` / ``stop()`` coroutines with
    ``asyncio.create_task``, so they must be called while an event loop is
    running.
    """

    def __init__(self):
        """Create an empty registry."""
        self.connections: Dict[str, RepoConnection] = {}
        # Strong references to in-flight start/stop tasks. The event loop only
        # keeps weak references, so an unreferenced task could be
        # garbage-collected before it finishes.
        self._background_tasks: set = set()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def register(self, repo):
        """Create and start a connection for ``repo`` unless one already exists.

        Args:
            repo: Object exposing ``id``, ``namespace`` and ``repo`` attributes.
        """
        if repo.id in self.connections:
            return
        conn = RepoConnection(repo.namespace, repo.repo, repo.id)
        self.connections[repo.id] = conn
        self._spawn(conn.start())

    def unregister(self, repo_id: str):
        """Schedule the connection for ``repo_id`` to stop and forget it.

        Unknown ids are ignored.
        """
        conn = self.connections.get(repo_id)
        if conn is None:
            return
        # Schedule first, then remove: if no event loop is running the entry
        # stays registered, as it did before.
        self._spawn(conn.stop())
        del self.connections[repo_id]

    def get(self, repo_id: str) -> "Optional[RepoConnection]":
        """Return the connection registered for ``repo_id``, or ``None``."""
        return self.connections.get(repo_id)


# Module-level registry instance created at import time.
registry = ConnectionRegistry()