Maksymilian Jankowski committed on
Commit
66ee4db
·
1 Parent(s): 7884075

deploy fix, removed Celery and Redis

Browse files
Files changed (5) hide show
  1. auth.py +1 -1
  2. requirements.txt +0 -2
  3. routers/user_models.py +25 -40
  4. utils/pubsub.py +52 -0
  5. utils/worker.py +37 -54
auth.py CHANGED
@@ -53,7 +53,7 @@ def get_dev_user() -> Optional[User]:
53
 
54
  try:
55
  # Use DEV_USER_ID if provided, otherwise use a default dev ID
56
- dev_user_id = os.getenv("DEV_USER_ID", "cc310cdf-83af-48e4-a341-d6c2f3b5462d")
57
 
58
  print(f"✅ Creating dev user: {DEV_USER_EMAIL} (ID: {dev_user_id})")
59
  return User(
 
53
 
54
  try:
55
  # Use DEV_USER_ID if provided, otherwise use a default dev ID
56
+ dev_user_id = os.getenv("DEV_USER_ID", "dev-user-id-change-this")
57
 
58
  print(f"✅ Creating dev user: {DEV_USER_EMAIL} (ID: {dev_user_id})")
59
  return User(
requirements.txt CHANGED
@@ -12,6 +12,4 @@ stripe>=5.0.0
12
  python-multipart>=0.0.20
13
  requests>=2.32.3
14
  PyJWT>=2.10.1
15
- celery>=5.3.0
16
- redis>=5.0.0
17
  sse-starlette>=1.3.2
 
12
  python-multipart>=0.0.20
13
  requests>=2.32.3
14
  PyJWT>=2.10.1
 
 
15
  sse-starlette>=1.3.2
routers/user_models.py CHANGED
@@ -9,8 +9,7 @@ import asyncio
9
  from pydantic import BaseModel
10
  from utils.worker import poll_meshy_task
11
  import json
12
- import redis
13
- import redis.asyncio as aioredis
14
 
15
  router = APIRouter(
16
  prefix="/user/models",
@@ -109,42 +108,31 @@ async def progress_update(model_id: str, current_user: User = Depends(get_curren
109
  async def sse_progress(model_id: str, request: Request, current_user: User = Depends(get_current_active_user)):
110
  """
111
  Server-Sent Events endpoint that streams realtime progress updates for a
112
- specific `model_id`. The Celery task publishes JSON messages to the
113
- Redis channel "tasks.progress". We listen on that pub/sub channel and
114
- forward only the events that match the requested model id.
115
  """
116
 
117
- # Establish async Redis connection
118
- redis_url = os.getenv("REDIS_URL", "redis://localhost/0")
119
- redis_conn = aioredis.from_url(redis_url, decode_responses=True)
120
-
121
  async def event_generator():
122
- pubsub = redis_conn.pubsub()
123
- await pubsub.subscribe("tasks.progress")
124
  try:
125
  while True:
126
- # Client has gone away – break the loop to finish response
127
  if await request.is_disconnected():
128
  break
129
 
130
- message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
131
- if message and message.get("type") == "message":
132
- try:
133
- data = json.loads(message["data"])
134
- except json.JSONDecodeError:
135
- data = None
136
 
137
- # Forward only messages for the requested task
138
- if data and data.get("taskId") == model_id:
139
- # Yield SSE-compatible dictionary. `sse-starlette`
140
- # handles conversion to the correct wire format.
141
- yield {"data": json.dumps(data)}
142
 
143
- # Give control back to the event loop
144
  await asyncio.sleep(0.1)
145
  finally:
146
- await pubsub.unsubscribe("tasks.progress")
147
- await pubsub.close()
148
 
149
  return EventSourceResponse(event_generator())
150
 
@@ -167,9 +155,9 @@ async def _check_and_decrement_credits(user_id: str):
167
 
168
 
169
  @router.post("/text-to-3d")
170
- async def text_to_3d(prompt: TextPrompt, current_user: User = Depends(get_current_active_user)):
171
  """Reframe user prompt via OpenAI, create a Meshy Text-to-3D task, and kick
172
- off a Celery poller that streams progress via SSE."""
173
 
174
  # Credit check and decrement
175
  await _check_and_decrement_credits(current_user.id)
@@ -281,23 +269,20 @@ async def text_to_3d(prompt: TextPrompt, current_user: User = Depends(get_curren
281
  except Exception as ex:
282
  logging.warning(f"Failed to update DB with Meshy taskId: {ex}")
283
 
284
- # 6. Kick off Celery poller + send initial SSE event
285
  if generated_model_id:
286
- # publish immediate queued event
287
  try:
288
- redis.Redis(host="localhost", decode_responses=True).publish(
289
- "tasks.progress",
290
- json.dumps({
291
- "taskId": str(generated_model_id),
292
- "progress": 0,
293
- "status": "queued",
294
- }),
295
- )
296
  except Exception as ex:
297
  logging.warning(f"Failed to publish initial SSE event: {ex}")
298
 
299
- # Fire background polling task
300
- poll_meshy_task.delay(generated_model_id, meshy_task_id)
301
 
302
  return {
303
  "generated_model_id": generated_model_id,
 
9
  from pydantic import BaseModel
10
  from utils.worker import poll_meshy_task
11
  import json
12
+ from utils.pubsub import pubsub as inmem_pubsub
 
13
 
14
  router = APIRouter(
15
  prefix="/user/models",
 
108
  async def sse_progress(model_id: str, request: Request, current_user: User = Depends(get_current_active_user)):
109
  """
110
  Server-Sent Events endpoint that streams realtime progress updates for a
111
+ specific `model_id`. The background polling task publishes JSON messages to the
112
+ internal in-memory pub/sub bus. We listen on that bus and forward only
113
+ the events that match the requested task id.
114
  """
115
 
 
 
 
 
116
  async def event_generator():
117
+ # Subscribe to in-memory pub/sub for this task
118
+ queue = await inmem_pubsub.subscribe(model_id)
119
  try:
120
  while True:
 
121
  if await request.is_disconnected():
122
  break
123
 
124
+ try:
125
+ data = await asyncio.wait_for(queue.get(), timeout=1.0)
126
+ except asyncio.TimeoutError:
127
+ data = None
 
 
128
 
129
+ if data:
130
+ yield {"data": json.dumps(data)}
 
 
 
131
 
132
+ # Small sleep to relinquish control
133
  await asyncio.sleep(0.1)
134
  finally:
135
+ await inmem_pubsub.unsubscribe(model_id, queue)
 
136
 
137
  return EventSourceResponse(event_generator())
138
 
 
155
 
156
 
157
  @router.post("/text-to-3d")
158
+ async def text_to_3d(prompt: TextPrompt, background_tasks: BackgroundTasks, current_user: User = Depends(get_current_active_user)):
159
  """Reframe user prompt via OpenAI, create a Meshy Text-to-3D task, and kick
160
+ off a FastAPI background poller that streams progress via SSE."""
161
 
162
  # Credit check and decrement
163
  await _check_and_decrement_credits(current_user.id)
 
269
  except Exception as ex:
270
  logging.warning(f"Failed to update DB with Meshy taskId: {ex}")
271
 
272
+ # 6. Kick off background poller + send initial SSE event
273
  if generated_model_id:
274
+ # publish immediate queued event via in-memory bus
275
  try:
276
+ await inmem_pubsub.publish(str(generated_model_id), {
277
+ "taskId": str(generated_model_id),
278
+ "progress": 0,
279
+ "status": "queued",
280
+ })
 
 
 
281
  except Exception as ex:
282
  logging.warning(f"Failed to publish initial SSE event: {ex}")
283
 
284
+ # Fire background polling task using FastAPI BackgroundTasks
285
+ background_tasks.add_task(poll_meshy_task, generated_model_id, meshy_task_id)
286
 
287
  return {
288
  "generated_model_id": generated_model_id,
utils/pubsub.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
from collections import defaultdict
from typing import Any, Dict, List


class InMemoryPubSub:
    """A minimal in-process pub/sub helper.

    Subscribers are plain ``asyncio.Queue`` instances. Each task_id can have
    many subscribers (e.g. multiple browser tabs connected to the same SSE
    stream). The publisher pushes a message **dict** (already JSON-serialisable)
    and every subscriber receives it.

    NOTE: This works only inside a single Python process. If you run several
    Uvicorn/Gunicorn workers the events won't be shared between them. Use a
    proper external broker (Redis, NATS, ...) for multi-process deployments.
    """

    def __init__(self) -> None:
        # task_id -> list of subscriber queues
        self._subscribers: Dict[str, List[asyncio.Queue]] = defaultdict(list)
        self._lock = asyncio.Lock()  # protect subscriber mutations

    async def publish(self, task_id: str, message: Any) -> None:
        """Broadcast *message* to every subscriber listening to *task_id*."""
        # Snapshot the subscriber list under the lock, then deliver OUTSIDE
        # the lock so delivery never serializes publishers behind a
        # concurrent subscribe/unsubscribe.
        async with self._lock:
            queues = list(self._subscribers.get(task_id, []))
        for queue in queues:
            # Queues are created unbounded (asyncio.Queue()), so put_nowait
            # can never raise QueueFull here.
            queue.put_nowait(message)

    async def subscribe(self, task_id: str) -> asyncio.Queue:
        """Return a brand-new ``asyncio.Queue`` subscribed to *task_id*."""
        queue: asyncio.Queue = asyncio.Queue()
        async with self._lock:
            self._subscribers[task_id].append(queue)
        return queue

    async def unsubscribe(self, task_id: str, queue: asyncio.Queue) -> None:
        """Remove *queue* from *task_id* subscriptions and cleanup.

        Dropping the task_id entry entirely once its last subscriber leaves
        keeps the registry from growing without bound.
        """
        async with self._lock:
            subscribers = self._subscribers.get(task_id)
            if subscribers is not None and queue in subscribers:
                subscribers.remove(queue)
                if not subscribers:
                    del self._subscribers[task_id]


# Global singleton used across the code-base
pubsub = InMemoryPubSub()

__all__ = ["pubsub", "InMemoryPubSub"]
utils/worker.py CHANGED
@@ -1,33 +1,14 @@
1
- import os, json, time, random
2
- from celery import Celery
3
- import redis
4
  import httpx
5
  from auth import supabase # re-use existing Supabase client
6
-
7
- R = redis.Redis(host="localhost", decode_responses=True)
8
- celery = Celery("tasks", broker=os.getenv("REDIS_URL", "redis://localhost/0"))
9
-
10
- @celery.task(bind=True)
11
- def external_job(self, task_id: str, payload: dict):
12
- """
13
- Poll a fictitious external API until done, emitting progress every ~2 s.
14
- """
15
- for p in range(0, 101, 10):
16
- # --- call the real external API here ----------------------------
17
- time.sleep(random.uniform(1.5, 2.5)) # <- placeholder
18
- # ----------------------------------------------------------------
19
- status = "running" if p < 100 else "completed"
20
- event = dict(taskId=task_id, progress=p, status=status)
21
- R.publish("tasks.progress", json.dumps(event)) # push to bus
22
- return {"result": "OK"}
23
 
24
  # ---------------------------------------------------------------------------
25
- # Real Meshy Text-to-3D polling task
26
  # ---------------------------------------------------------------------------
27
 
28
- @celery.task(bind=True, max_retries=3, default_retry_delay=30)
29
- def poll_meshy_task(self, generated_model_id: int, meshy_task_id: str):
30
- """Poll Meshy API every few seconds, publish progress to Redis and update DB.
31
 
32
  Parameters
33
  ----------
@@ -40,53 +21,52 @@ def poll_meshy_task(self, generated_model_id: int, meshy_task_id: str):
40
 
41
  meshy_api_key = os.getenv("MESHY_API_KEY")
42
  if not meshy_api_key:
43
- raise RuntimeError("MESHY_API_KEY env-var missing – cannot poll Meshy API")
 
44
 
45
  headers = {"Authorization": f"Bearer {meshy_api_key}"}
46
 
47
- with httpx.Client(timeout=30.0) as client:
48
- last_progress = -1 # to avoid spamming identical values
 
49
  while True:
50
  try:
51
- resp = client.get(
52
  f"https://api.meshy.ai/openapi/v2/text-to-3d/{meshy_task_id}",
53
  headers=headers,
54
  )
55
  except Exception as exc:
56
- # network hiccup retry a few times
57
- self.retry(exc=exc)
 
58
 
59
  if resp.status_code != 200:
60
- # Non-200 from Meshy; treat as terminal failure and notify.
61
- status = "FAILED"
62
  data = {
63
  "taskId": str(generated_model_id),
64
- "progress": last_progress if last_progress >= 0 else 0,
65
- "status": status,
66
  "error": f"Meshy API error {resp.status_code}",
67
  }
68
- R.publish("tasks.progress", json.dumps(data))
69
  break
70
 
71
  payload = resp.json()
72
  status = payload.get("status", "UNKNOWN")
73
  progress = payload.get("progress", 0)
74
 
75
- # Publish only if progress changed or when completed/failed
76
  if progress != last_progress or status in ("SUCCEEDED", "FAILED"):
77
- R.publish(
78
- "tasks.progress",
79
- json.dumps({
80
- "taskId": str(generated_model_id),
81
- "progress": progress,
82
- "status": status.lower(), # 'succeeded' -> 'succeeded'
83
- "model_urls": payload.get("model_urls"),
84
- "thumbnail_url": payload.get("thumbnail_url"),
85
- }),
86
- )
87
  last_progress = progress
88
 
89
- # If finished, update the DB and exit loop
90
  if status == "SUCCEEDED":
91
  try:
92
  supabase.from_("Generated_Models").update({
@@ -94,10 +74,10 @@ def poll_meshy_task(self, generated_model_id: int, meshy_task_id: str):
94
  "status": "COMPLETED",
95
  "updated_at": "now()",
96
  }).eq("generated_model_id", generated_model_id).execute()
97
- except Exception:
98
- # Non-critical continue anyway
99
- pass
100
  break
 
101
  if status == "FAILED":
102
  try:
103
  supabase.from_("Generated_Models").update({
@@ -105,9 +85,12 @@ def poll_meshy_task(self, generated_model_id: int, meshy_task_id: str):
105
  "status": "FAILED",
106
  "updated_at": "now()",
107
  }).eq("generated_model_id", generated_model_id).execute()
108
- except Exception:
109
- pass
110
  break
111
 
112
- # Keep polling every 5 seconds
113
- time.sleep(5)
 
 
 
 
1
+ import os, asyncio, logging
 
 
2
  import httpx
3
  from auth import supabase # re-use existing Supabase client
4
+ from utils.pubsub import pubsub as inmem_pubsub
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # ---------------------------------------------------------------------------
7
+ # Real Meshy Text-to-3D polling task – executed as FastAPI BackgroundTask
8
  # ---------------------------------------------------------------------------
9
 
10
+ async def poll_meshy_task(generated_model_id: int, meshy_task_id: str):
11
+ """Poll Meshy API every few seconds, publish progress to in-memory pub/sub and update DB.
 
12
 
13
  Parameters
14
  ----------
 
21
 
22
  meshy_api_key = os.getenv("MESHY_API_KEY")
23
  if not meshy_api_key:
24
+ logging.error("MESHY_API_KEY env-var missing – cannot poll Meshy API")
25
+ return
26
 
27
  headers = {"Authorization": f"Bearer {meshy_api_key}"}
28
 
29
+ last_progress = -1 # to avoid spamming identical values
30
+
31
+ async with httpx.AsyncClient(timeout=30.0) as client:
32
  while True:
33
  try:
34
+ resp = await client.get(
35
  f"https://api.meshy.ai/openapi/v2/text-to-3d/{meshy_task_id}",
36
  headers=headers,
37
  )
38
  except Exception as exc:
39
+ logging.warning(f"Meshy API request failed: {exc}; retrying in 30s")
40
+ await asyncio.sleep(30)
41
+ continue
42
 
43
  if resp.status_code != 200:
44
+ # Meshy returned an error notify and abort.
 
45
  data = {
46
  "taskId": str(generated_model_id),
47
+ "progress": max(last_progress, 0),
48
+ "status": "failed",
49
  "error": f"Meshy API error {resp.status_code}",
50
  }
51
+ await inmem_pubsub.publish(str(generated_model_id), data)
52
  break
53
 
54
  payload = resp.json()
55
  status = payload.get("status", "UNKNOWN")
56
  progress = payload.get("progress", 0)
57
 
58
+ # Publish progress updates (avoid duplicates)
59
  if progress != last_progress or status in ("SUCCEEDED", "FAILED"):
60
+ await inmem_pubsub.publish(str(generated_model_id), {
61
+ "taskId": str(generated_model_id),
62
+ "progress": progress,
63
+ "status": status.lower(),
64
+ "model_urls": payload.get("model_urls"),
65
+ "thumbnail_url": payload.get("thumbnail_url"),
66
+ })
 
 
 
67
  last_progress = progress
68
 
69
+ # Handle terminal states
70
  if status == "SUCCEEDED":
71
  try:
72
  supabase.from_("Generated_Models").update({
 
74
  "status": "COMPLETED",
75
  "updated_at": "now()",
76
  }).eq("generated_model_id", generated_model_id).execute()
77
+ except Exception as exc:
78
+ logging.warning(f"Failed to update DB on success: {exc}")
 
79
  break
80
+
81
  if status == "FAILED":
82
  try:
83
  supabase.from_("Generated_Models").update({
 
85
  "status": "FAILED",
86
  "updated_at": "now()",
87
  }).eq("generated_model_id", generated_model_id).execute()
88
+ except Exception as exc:
89
+ logging.warning(f"Failed to update DB on failure: {exc}")
90
  break
91
 
92
+ # Sleep between polls
93
+ await asyncio.sleep(5)
94
+
95
+ # Provide explicit __all__ for clarity
96
+ __all__ = ["poll_meshy_task"]