Spaces:
Sleeping
Sleeping
File size: 6,672 Bytes
9b2dc95 a0bb6c6 9b2dc95 867e59a 9b2dc95 7e1df57 9b2dc95 7e1df57 9b2dc95 ddab84d f548e06 ddab84d f548e06 9b2dc95 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 | import asyncio
import json
import logging
import aiohttp
from typing import Dict, Optional
from datetime import datetime
from services.storage import storage
# Configure the root logger at INFO level; this runs as an import-time side
# effect of loading this module.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the connection classes in this module.
logger = logging.getLogger(__name__)
class RepoConnection:
    """Follow one Hugging Face Space's event and metric streams.

    A background task reads the Space's ``/events`` server-sent-event (SSE)
    stream and writes the connection status and compute stage through
    ``storage.update_status``. While the reported stage is ``"RUNNING"``, a
    second task reads the ``/metrics`` stream and passes every JSON payload
    to ``storage.add_metric``. Both loops reconnect with capped exponential
    backoff until :meth:`stop` is called.
    """

    # Per-phase timeouts in seconds. There is deliberately no *total* timeout:
    # aiohttp's ``total`` also bounds reading the response body, so on these
    # long-lived streams it would tear down every healthy connection after
    # that many seconds.
    _CONNECT_TIMEOUT = 60.0
    _READ_TIMEOUT = 60.0
    # Largest single reconnect delay, in seconds.
    _MAX_BACKOFF = 30.0

    def __init__(self, namespace: str, repo: str, repo_id: str):
        self.namespace = namespace
        self.repo = repo
        self.repo_id = repo_id
        self.events_session: Optional[aiohttp.ClientSession] = None
        self.metrics_session: Optional[aiohttp.ClientSession] = None
        self.events_task: Optional[asyncio.Task] = None
        self.metrics_task: Optional[asyncio.Task] = None
        self._running = False
        self._metrics_active = False
        # Consecutive reconnect attempts, counted separately per stream so
        # failures on one stream do not lengthen the other's backoff.
        self._retry_count = 0
        self._metrics_retry_count = 0
        # Caps the backoff exponent; reconnecting itself never gives up.
        self._max_retries = 10

    async def start(self):
        """Begin following the events stream in a background task."""
        self._running = True
        self.events_task = asyncio.create_task(self._listen_events())
        logger.info(f"Started connections for {self.namespace}/{self.repo}")

    async def stop(self):
        """Stop both streams, wait for their tasks to exit, and close sessions."""
        self._running = False
        self._metrics_active = False
        await self._cancel_and_wait(self.events_task)
        await self._cancel_and_wait(self.metrics_task)
        if self.events_session:
            await self.events_session.close()
        if self.metrics_session:
            await self.metrics_session.close()
        logger.info(f"Stopped connections for {self.namespace}/{self.repo}")

    def _stream_timeout(self) -> "aiohttp.ClientTimeout":
        """Build the timeout shared by both streaming requests.

        Only connection setup and the gap between reads are bounded, so a
        stream that keeps delivering data can stay open indefinitely.
        """
        return aiohttp.ClientTimeout(
            total=None,
            connect=self._CONNECT_TIMEOUT,
            sock_read=self._READ_TIMEOUT,
        )

    async def _listen_events(self):
        """Read the events stream until stopped, reconnecting after failures.

        A read timeout on an already-established stream only means it was
        idle, so the loop reconnects immediately without touching the stored
        status. Any other failure is logged, stored as ``"ERROR"``, and
        followed by a backoff delay. Cancellation is not caught, so a
        cancelled task ends in the cancelled state.
        """
        url = f"https://huggingface.co/api/spaces/{self.namespace}/{self.repo}/events"
        while self._running:
            connected = False
            try:
                async with aiohttp.ClientSession() as session:
                    self.events_session = session
                    async with session.get(url, timeout=self._stream_timeout()) as resp:
                        if resp.status != 200:
                            raise Exception(f"HTTP {resp.status}")
                        connected = True
                        storage.update_status(self.repo_id, "CONNECTED", "Unknown")
                        self._retry_count = 0
                        async for line in resp.content:
                            if not self._running:
                                break
                            await self._handle_event_line(line)
            except Exception as e:
                if connected and isinstance(e, asyncio.TimeoutError):
                    logger.debug(
                        f"Events stream idle for {self.namespace}/{self.repo}; reconnecting"
                    )
                    continue
                # repr() keeps exceptions with an empty str(), such as timeouts, readable.
                logger.error(f"Events stream error for {self.namespace}/{self.repo}: {e!r}")
                storage.update_status(self.repo_id, "ERROR", "Unknown")
            if self._running:
                # Back off after a clean close as well, so a server that keeps
                # ending the stream cannot drive a tight reconnect loop.
                await asyncio.sleep(await self._exponential_backoff())

    async def _handle_event_line(self, line: bytes) -> None:
        """Apply one raw events-stream line: store its stage, toggle metrics.

        Non-``data`` lines, undecodable JSON and non-object payloads are
        skipped instead of aborting the stream.
        """
        payload = self._sse_data(line)
        if payload is None:
            return
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            return
        if not isinstance(data, dict):
            return  # skip unexpected payloads
        stage = self._stage_of(data)
        storage.update_status(self.repo_id, "CONNECTED", stage)
        if stage == "RUNNING" and not self._metrics_active:
            self._metrics_active = True
            self._metrics_retry_count = 0
            self.metrics_task = asyncio.create_task(self._listen_metrics())
        elif stage != "RUNNING" and self._metrics_active:
            self._metrics_active = False
            await self._cancel_and_wait(self.metrics_task)

    async def _listen_metrics(self):
        """Forward metrics-stream payloads to storage while metrics are active.

        It reconnects on the same rules as :meth:`_listen_events`, but uses
        its own retry counter. A malformed line is skipped rather than
        dropping the connection.
        """
        url = f"https://huggingface.co/api/spaces/{self.namespace}/{self.repo}/metrics"
        while self._running and self._metrics_active:
            connected = False
            try:
                async with aiohttp.ClientSession() as session:
                    self.metrics_session = session
                    async with session.get(url, timeout=self._stream_timeout()) as resp:
                        if resp.status != 200:
                            raise Exception(f"HTTP {resp.status}")
                        connected = True
                        self._metrics_retry_count = 0
                        async for line in resp.content:
                            if not self._running or not self._metrics_active:
                                break
                            payload = self._sse_data(line)
                            if payload is None:
                                continue
                            try:
                                data = json.loads(payload)
                            except json.JSONDecodeError:
                                logger.debug(
                                    f"Skipping malformed metrics line for "
                                    f"{self.namespace}/{self.repo}"
                                )
                                continue
                            storage.add_metric(self.repo_id, data)
            except Exception as e:
                if connected and isinstance(e, asyncio.TimeoutError):
                    logger.debug(
                        f"Metrics stream idle for {self.namespace}/{self.repo}; reconnecting"
                    )
                    continue
                logger.error(f"Metrics stream error for {self.namespace}/{self.repo}: {e!r}")
            if self._running and self._metrics_active:
                await asyncio.sleep(await self._exponential_backoff(metrics=True))

    async def _exponential_backoff(self, *, metrics: bool = False) -> float:
        """Return the next reconnect delay for one stream and advance its counter.

        The delay starts at 1 s, doubles per consecutive failure, and is
        capped at ``_MAX_BACKOFF`` seconds. The counter itself is capped at
        ``_max_retries``.

        Args:
            metrics: Use the metrics stream's counter instead of the events
                stream's counter.
        """
        attempt = self._metrics_retry_count if metrics else self._retry_count
        delay = min(1.0 * (2 ** attempt), self._MAX_BACKOFF)
        attempt = min(attempt + 1, self._max_retries)
        if metrics:
            self._metrics_retry_count = attempt
        else:
            self._retry_count = attempt
        return delay

    @staticmethod
    async def _cancel_and_wait(task: Optional[asyncio.Task]) -> None:
        """Cancel *task* (if any) and wait until it has finished.

        ``gather(..., return_exceptions=True)`` absorbs the CancelledError
        produced by *task* itself. A cancellation aimed at the *calling*
        task still propagates instead of being swallowed.
        """
        if task is None:
            return
        task.cancel()
        for outcome in await asyncio.gather(task, return_exceptions=True):
            if isinstance(outcome, Exception) and not isinstance(
                outcome, asyncio.CancelledError
            ):
                logger.error(f"Background task failed: {outcome!r}")

    @staticmethod
    def _sse_data(raw: bytes) -> Optional[str]:
        """Return the stripped value of a ``data:`` line, or None for other lines.

        Undecodable bytes are replaced rather than raising, so one corrupt
        line cannot abort the stream.
        """
        field, _, value = raw.decode("utf-8", errors="replace").strip().partition(":")
        return value.strip() if field == "data" else None

    @staticmethod
    def _stage_of(event: dict) -> str:
        """Return ``event["compute"]["status"]["stage"]``, or ``"Unknown"``.

        ``"Unknown"`` is returned when any level is missing, null, or not of
        the expected type.
        """
        compute = event.get("compute")
        status = compute.get("status") if isinstance(compute, dict) else None
        stage = status.get("stage") if isinstance(status, dict) else None
        return stage if isinstance(stage, str) else "Unknown"
class ConnectionRegistry:
    """Keep at most one RepoConnection per repo id and drive its lifecycle.

    ``register`` and ``unregister`` schedule work with
    ``asyncio.create_task``, so they must be called while an event loop is
    running (except when they return early without scheduling anything).
    """

    def __init__(self):
        self.connections: Dict[str, RepoConnection] = {}
        # Strong references to the fire-and-forget start/stop tasks. The event
        # loop only keeps weak references to tasks, so an unreferenced task
        # can be garbage-collected before it finishes.
        self._background_tasks: set = set()

    def _spawn(self, coro) -> None:
        """Schedule *coro* as a task and keep it referenced until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def register(self, repo):
        """Create and start a connection for *repo* unless its id is already known.

        *repo* must expose ``id``, ``namespace`` and ``repo`` attributes.
        """
        if repo.id in self.connections:
            return
        conn = RepoConnection(repo.namespace, repo.repo, repo.id)
        self.connections[repo.id] = conn
        self._spawn(conn.start())

    def unregister(self, repo_id: str):
        """Forget *repo_id* and stop its connection in the background; no-op if unknown."""
        conn = self.connections.get(repo_id)
        if conn is None:
            return
        # Schedule before removing, so a failure to schedule (no running loop)
        # leaves the registry unchanged.
        self._spawn(conn.stop())
        del self.connections[repo_id]

    def get(self, repo_id: str) -> "Optional[RepoConnection]":
        """Return the connection registered under *repo_id*, or None."""
        return self.connections.get(repo_id)
# Process-wide registry of active repo connections.
registry = ConnectionRegistry()