Spaces:
Sleeping
Sleeping
| import os, uuid, asyncio, json, shutil, tempfile | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.responses import FileResponse, JSONResponse, HTMLResponse, Response | |
| import uvicorn | |
app = FastAPI()

# --- Configuration, read once from the environment at import time ---
KAGGLE_USER = os.getenv("KAGGLE_USERNAME", "")
KAGGLE_KEY = os.getenv("KAGGLE_KEY", "")
# Public base URL of this server; a trailing "/" is stripped so paths can be appended.
SPACE_URL = os.getenv("SPACE_URL", "").rstrip("/")
# "<user>/trellis-worker": the Kaggle kernel slug the worker notebook is pushed to.
KERNEL_REF = f"{KAGGLE_USER}/trellis-worker"

# --- In-process state (plain dicts, nothing is persisted) ---
# job_id -> {"status": str, "glb_path": str | None}
jobs = {}
# job_id -> raw bytes of the uploaded input image
images = {}
async def home():
    """Render a minimal HTML dashboard listing every job and its status.

    Statuses are HTML-escaped because error statuses embed exception text
    (``f"error:{e}"``), which must not be injected into the page as markup.

    NOTE(review): no route decorator is visible for this handler, and it
    returns a plain ``str``; presumably it is registered on ``/`` with
    ``response_class=HTMLResponse`` -- confirm.
    """
    from html import escape  # local import: only this handler needs it

    rows = "".join(
        f"<tr><td>{escape(jid[:8])}…</td><td>{escape(str(info['status']))}</td></tr>"
        for jid, info in jobs.items()
    )
    return f"""
<html><body style='font-family:monospace;padding:2rem'>
<h2>🧊 TRELLIS Worker Space</h2>
<table border=1 cellpadding=6>
<tr><th>Job ID</th><th>Status</th></tr>
{rows or '<tr><td colspan=2>No jobs yet</td></tr>'}
</table>
<p>POST /generate with form-data file=image.png to start a job.</p>
</body></html>
"""
# Strong references to in-flight background tasks. The asyncio event loop only
# keeps weak references to tasks, so a task nobody references can be garbage
# collected before it finishes.
_background_tasks = set()


async def generate(file: UploadFile = File(...)):
    """Queue a new image-to-3D job and start its Kaggle run in the background.

    Stores the uploaded bytes in ``images``, registers the job in ``jobs`` as
    ``"queued"``, and schedules ``trigger_kaggle`` without awaiting it.

    Returns:
        ``{"job_id": ..., "status": "queued"}`` on success, or a 400 JSON error
        when the upload is empty (the worker would have no image to process).

    NOTE(review): no route decorator is visible here; the home page text
    advertises this handler as ``POST /generate``.
    """
    data = await file.read()
    if not data:
        return JSONResponse({"error": "empty file"}, status_code=400)

    job_id = str(uuid.uuid4())
    images[job_id] = data
    jobs[job_id] = {"status": "queued", "glb_path": None}

    task = asyncio.create_task(trigger_kaggle(job_id))
    _background_tasks.add(task)
    # Release the reference once the task completes so the set does not grow forever.
    task.add_done_callback(_background_tasks.discard)
    return {"job_id": job_id, "status": "queued"}
async def get_image(job_id: str):
    """Serve the uploaded input image for *job_id* (always labelled ``image/png``).

    Responds 404 ``{"error": "not found"}`` if no image is stored for the id,
    e.g. after ``receive_glb`` has discarded it.

    NOTE(review): no route decorator is visible; the worker notebook fetches
    this as ``GET {SPACE_URL}/image/{JOB_ID}``.
    """
    payload = images.get(job_id)
    if payload is None:
        return JSONResponse({"error": "not found"}, status_code=404)
    return Response(payload, media_type="image/png")
async def status(job_id: str):
    """Return the stored job record for *job_id*.

    The record is returned as-is (its ``status`` and ``glb_path`` keys). An
    unknown id yields ``{"status": "not_found"}`` rather than an HTTP error.

    NOTE(review): no route decorator is visible; confirm the registered path.
    """
    try:
        return jobs[job_id]
    except KeyError:
        return {"status": "not_found"}
async def receive_glb(job_id: str, file: UploadFile = File(...)):
    """Store the GLB uploaded by the Kaggle worker and mark the job ``"done"``.

    Unknown job ids are rejected with 404 *before* anything touches the disk.
    ``job_id`` comes straight from the request and is interpolated into a
    filesystem path, so accepting arbitrary values would let a caller write
    ``.glb`` files outside ``/tmp`` (e.g. ``job_id="../x"``). Only ids created
    by ``generate`` (UUID strings) are present in ``jobs``.

    On success the stored input image is discarded, since the worker no longer
    needs it.

    NOTE(review): no route decorator is visible; the worker notebook POSTs to
    ``{SPACE_URL}/receive_glb`` with ``job_id`` as a query parameter. The
    endpoint is unauthenticated -- anyone knowing a job id can complete it.
    """
    job = jobs.get(job_id)
    if job is None:
        return JSONResponse({"error": "not found"}, status_code=404)

    # Read the upload first so a failed read never leaves a truncated file behind.
    data = await file.read()
    glb_path = f"/tmp/{job_id}.glb"
    with open(glb_path, "wb") as f:
        f.write(data)

    job["status"] = "done"
    job["glb_path"] = glb_path
    images.pop(job_id, None)
    return {"ok": True}
async def download(job_id: str):
    """Return the finished mesh for *job_id* as a ``mesh.glb`` attachment.

    Both unknown ids and jobs whose status is not yet ``"done"`` get a 404
    ``{"error": "not ready yet"}`` response.

    NOTE(review): no route decorator is visible; confirm the registered path.
    """
    job = jobs.get(job_id)
    if job and job["status"] == "done":
        return FileResponse(
            job["glb_path"],
            filename="mesh.glb",
            media_type="model/gltf-binary",
        )
    return JSONResponse({"error": "not ready yet"}, status_code=404)
def make_notebook(job_id: str, space_url: str) -> str:
    """Build the single-cell Kaggle worker notebook for one job.

    The generated code:

    1. Downloads the input image from ``{space_url}/image/{job_id}``.
    2. Removes its background with ``transparent_background``.
    3. Asks the ``trellis-community/TRELLIS`` Gradio Space for a mesh, trying
       up to 3 times with 30 s between attempts.
    4. POSTs the resulting GLB to ``{space_url}/receive_glb``.

    Args:
        job_id: Job identifier, baked into the notebook as ``JOB_ID``.
        space_url: Base URL of this server, baked in as ``SPACE_URL``. It is
            expected without a trailing slash, because paths are appended
            with ``/``.

    Returns:
        The notebook serialized as nbformat 4.5 JSON text.
    """
    # repr() always yields a valid Python string literal, so quotes or
    # backslashes in either value cannot break (or inject into) the generated code.
    source = [
        "import os, requests, time\n",
        "\n",
        f"JOB_ID = {job_id!r}\n",
        f"SPACE_URL = {space_url!r}\n",
        "\n",
        'print(f"Job: {JOB_ID}")\n',
        'print(f"Space: {SPACE_URL}")\n',
        "\n",
        "# 1. Download image\n",
        'print("Downloading image...")\n',
        'r = requests.get(f"{SPACE_URL}/image/{JOB_ID}", timeout=30)\n',
        "r.raise_for_status()\n",
        'with open("/kaggle/working/input.png", "wb") as f:\n',
        "    f.write(r.content)\n",
        'print("Image downloaded!")\n',
        "\n",
        "# 2. Remove background\n",
        'os.system("pip install -q transparent-background gradio_client pillow")\n',
        "from transparent_background import Remover\n",
        "from PIL import Image\n",
        "\n",
        'print("Removing background...")\n',
        "remover = Remover()\n",
        'img = Image.open("/kaggle/working/input.png").convert("RGB")\n',
        'out = remover.process(img, type="rgba")\n',
        'out.save("/kaggle/working/input_nobg.png")\n',
        'print("BG removed!")\n',
        "\n",
        "# 3. Run TRELLIS\n",
        "from gradio_client import Client, handle_file\n",
        "\n",
        "MAX_RETRIES = 3\n",
        "result = None\n",
        "\n",
        # Indentation inside these strings is significant: it is the generated
        # notebook's own block structure (for -> try/except -> if/else).
        "for attempt in range(1, MAX_RETRIES + 1):\n",
        "    try:\n",
        '        print(f"Connecting to TRELLIS (attempt {attempt}/{MAX_RETRIES})...")\n',
        '        client = Client("trellis-community/TRELLIS")\n',
        '        client.predict(api_name="/start_session")\n',
        '        print("Session ready! Generating...")\n',
        "\n",
        "        result = client.predict(\n",
        '            image=handle_file("/kaggle/working/input_nobg.png"),\n',
        "            multiimages=[],\n",
        "            seed=0,\n",
        "            ss_guidance_strength=7.5,\n",
        "            ss_sampling_steps=12,\n",
        "            slat_guidance_strength=3.0,\n",
        "            slat_sampling_steps=12,\n",
        '            multiimage_algo="stochastic",\n',
        "            mesh_simplify=0.95,\n",
        "            texture_size=1024,\n",
        '            api_name="/generate_and_extract_glb"\n',
        "        )\n",
        '        print("Generation done!")\n',
        "        break\n",
        "\n",
        "    except Exception as e:\n",
        '        print(f"Attempt {attempt} failed: {e}")\n',
        "        if attempt < MAX_RETRIES:\n",
        "            time.sleep(30)\n",
        "        else:\n",
        '            raise RuntimeError(f"TRELLIS failed after {MAX_RETRIES} attempts: {e}")\n',
        "\n",
        "# 4. POST GLB back\n",
        "glb_src = result[1] or result[2]\n",
        'print(f"Sending GLB ({os.path.getsize(glb_src)/1024/1024:.1f} MB)...")\n',
        "\n",
        "with open(glb_src, 'rb') as f:\n",
        "    resp = requests.post(\n",
        '        f"{SPACE_URL}/receive_glb",\n',
        '        params={"job_id": JOB_ID},\n',
        '        files={"file": ("mesh.glb", f, "model/gltf-binary")},\n',
        "        timeout=120\n",
        "    )\n",
        "\n",
        "resp.raise_for_status()\n",
        'print(f"GLB delivered! Job {JOB_ID} complete.")\n',
    ]
    notebook = {
        "cells": [
            {
                "cell_type": "code",
                "execution_count": None,
                "id": "6117abdc",
                "metadata": {},
                "outputs": [],
                "source": source,
            }
        ],
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {"name": "python", "version": "3.12.0"},
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    return json.dumps(notebook, indent=1)
async def _push_kernel(job_id: str):
    """Push a freshly generated worker notebook for *job_id* with the ``kaggle`` CLI.

    Writes ``kernel-metadata.json`` and ``worker.ipynb`` into a temporary
    directory and runs ``kaggle kernels push -p <dir>``. The directory is
    removed once the CLI has exited, even on error.

    Returns:
        ``(returncode, stdout_text, stderr_text)`` of the CLI process.
    """
    meta = {
        "id": KERNEL_REF,
        "title": "trellis-worker",
        "code_file": "worker.ipynb",
        "language": "python",
        "kernel_type": "notebook",
        "is_private": True,
        "enable_gpu": False,
        "enable_internet": True,
        "dataset_sources": [],
        "competition_sources": [],
        "kernel_sources": [],
    }
    env = {**os.environ, "KAGGLE_USERNAME": KAGGLE_USER, "KAGGLE_KEY": KAGGLE_KEY}
    with tempfile.TemporaryDirectory() as work_dir:
        with open(os.path.join(work_dir, "kernel-metadata.json"), "w", encoding="utf-8") as f:
            json.dump(meta, f)
        # JOB_ID and SPACE_URL are baked into the notebook, so no env vars are needed.
        with open(os.path.join(work_dir, "worker.ipynb"), "w", encoding="utf-8") as f:
            f.write(make_notebook(job_id, SPACE_URL))
        proc = await asyncio.create_subprocess_exec(
            "kaggle", "kernels", "push", "-p", work_dir,
            env=env,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await proc.communicate()
    return (
        proc.returncode,
        stdout.decode(errors="replace"),
        stderr.decode(errors="replace"),
    )


async def trigger_kaggle(job_id: str):
    """Run one job on Kaggle and track its progress in ``jobs``.

    Marks the job ``"running"`` and pushes the worker notebook. It then polls
    every 15 s, up to 60 times (~15 min), for ``receive_glb`` to set the status
    to ``"done"``. Failures are recorded as ``"error:<reason>"`` instead of
    being raised, because ``generate`` runs this coroutine as a background task
    that nobody awaits.

    NOTE(review): every job pushes a different notebook to the same
    ``KERNEL_REF``, so concurrent jobs presumably replace each other's kernel
    version -- confirm how Kaggle handles overlapping pushes.
    """
    tag = job_id[:8]
    try:
        jobs[job_id]["status"] = "running"
        returncode, out, err = await _push_kernel(job_id)
        print(f"[{tag}] Kaggle push: {out} {err}")
        if returncode != 0:
            # The push failed, so presumably no worker will ever fetch the image:
            # fail fast instead of polling for 15 minutes, and free the upload.
            jobs[job_id]["status"] = f"error:kaggle push exited with code {returncode}"
            images.pop(job_id, None)
            return
        for _ in range(60):  # 60 polls x 15 s = wait up to 15 min
            await asyncio.sleep(15)
            if jobs[job_id]["status"] == "done":
                print(f"[{tag}] Done!")
                return
        jobs[job_id]["status"] = "error:timeout"
    except Exception as e:
        # Top-level boundary of a background task: record the failure on the job.
        print(f"[{tag}] Error: {e}")
        jobs[job_id]["status"] = f"error:{e}"
        images.pop(job_id, None)
if __name__ == "__main__":
    # Serve the app directly, listening on every interface on port 7860
    # (presumably the port the hosting Space exposes -- confirm).
    uvicorn.run(app=app, host="0.0.0.0", port=7860)