"""FastAPI service that accepts an image upload, triggers a Kaggle kernel
running TRELLIS (via a generated notebook), and serves the resulting GLB.

Flow: POST /generate stores the image and pushes a worker notebook to Kaggle;
the notebook fetches the image back from this Space, runs background removal
plus TRELLIS, then POSTs the GLB to /receive_glb; clients poll /status and
fetch the mesh from /download.
"""
import asyncio
import json
import os
import shutil
import tempfile
import uuid

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, Response
import uvicorn

app = FastAPI()

KAGGLE_USER = os.environ.get("KAGGLE_USERNAME", "")
KAGGLE_KEY = os.environ.get("KAGGLE_KEY", "")
SPACE_URL = os.environ.get("SPACE_URL", "").rstrip("/")
KERNEL_REF = f"{KAGGLE_USER}/trellis-worker"

# In-memory job state (lost on restart — acceptable for a single-process Space).
jobs = {}    # job_id -> {"status": str, "glb_path": str | None}
images = {}  # job_id -> raw uploaded image bytes (held until the kernel fetches them)


@app.get("/", response_class=HTMLResponse)
async def home():
    """Minimal dashboard: one table row per known job with its status."""
    rows = "".join(
        f"<tr><td>{jid[:8]}…</td><td>{info['status']}</td></tr>"
        for jid, info in jobs.items()
    )
    return f"""
    <html>
    <head><title>TRELLIS Worker Space</title></head>
    <body>
    <h1>🧊 TRELLIS Worker Space</h1>
    <table border="1" cellpadding="4">
    <tr><th>Job ID</th><th>Status</th></tr>
    {rows or '<tr><td colspan="2">No jobs yet</td></tr>'}
    </table>
    <p>POST /generate with form-data file=image.png to start a job.</p>
    </body>
    </html>
    """


@app.post("/generate")
async def generate(file: UploadFile = File(...)):
    """Register a new job for the uploaded image and fire off the Kaggle push."""
    job_id = str(uuid.uuid4())
    images[job_id] = await file.read()
    jobs[job_id] = {"status": "queued", "glb_path": None}
    # Fire-and-forget: the task updates jobs[job_id] as it progresses.
    asyncio.create_task(trigger_kaggle(job_id))
    return {"job_id": job_id, "status": "queued"}


@app.get("/image/{job_id}")
async def get_image(job_id: str):
    """Serve the stored source image so the Kaggle kernel can download it."""
    if job_id not in images:
        return JSONResponse({"error": "not found"}, status_code=404)
    return Response(images[job_id], media_type="image/png")


@app.get("/status/{job_id}")
async def status(job_id: str):
    """Return the job record, or a sentinel status for unknown ids."""
    return jobs.get(job_id, {"status": "not_found"})


@app.post("/receive_glb")
async def receive_glb(job_id: str, file: UploadFile = File(...)):
    """Callback endpoint the Kaggle kernel POSTs the finished GLB to."""
    # Reject unknown ids: avoids a KeyError (500) on jobs[job_id] below, and —
    # since valid ids are server-generated UUIDs — also prevents path traversal
    # through the attacker-controllable job_id in the file path.
    if job_id not in jobs:
        return JSONResponse({"error": "unknown job"}, status_code=404)
    glb_path = f"/tmp/{job_id}.glb"
    with open(glb_path, "wb") as f:
        f.write(await file.read())
    jobs[job_id]["status"] = "done"
    jobs[job_id]["glb_path"] = glb_path
    images.pop(job_id, None)  # image no longer needed; free the memory
    return {"ok": True}


@app.get("/download/{job_id}")
async def download(job_id: str):
    """Serve the finished GLB; 404 until the job reaches 'done'."""
    job = jobs.get(job_id)
    if not job or job["status"] != "done":
        return JSONResponse({"error": "not ready yet"}, status_code=404)
    return FileResponse(job["glb_path"], filename="mesh.glb",
                        media_type="model/gltf-binary")


def make_notebook(job_id: str, space_url: str) -> str:
    """Generate the worker notebook with JOB_ID and SPACE_URL baked in.

    Returns the notebook as an nbformat-4 JSON string ready for
    `kaggle kernels push`. The notebook: downloads the image from this Space,
    removes its background, runs TRELLIS via gradio_client with retries, and
    POSTs the resulting GLB back to /receive_glb.
    """
    source = [
        "import os, requests, time\n",
        "\n",
        f'JOB_ID = "{job_id}"\n',
        f'SPACE_URL = "{space_url}"\n',
        "\n",
        'print(f"Job: {JOB_ID}")\n',
        'print(f"Space: {SPACE_URL}")\n',
        "\n",
        "# 1. Download image\n",
        'print("Downloading image...")\n',
        'r = requests.get(f"{SPACE_URL}/image/{JOB_ID}", timeout=30)\n',
        "r.raise_for_status()\n",
        'with open("/kaggle/working/input.png", "wb") as f:\n',
        "    f.write(r.content)\n",
        'print("Image downloaded!")\n',
        "\n",
        "# 2. Remove background\n",
        'os.system("pip install -q transparent-background gradio_client pillow")\n',
        "from transparent_background import Remover\n",
        "from PIL import Image\n",
        "\n",
        'print("Removing background...")\n',
        "remover = Remover()\n",
        'img = Image.open("/kaggle/working/input.png").convert("RGB")\n',
        'out = remover.process(img, type="rgba")\n',
        'out.save("/kaggle/working/input_nobg.png")\n',
        'print("BG removed!")\n',
        "\n",
        "# 3. Run TRELLIS\n",
        "from gradio_client import Client, handle_file\n",
        "\n",
        "MAX_RETRIES = 3\n",
        "result = None\n",
        "\n",
        "for attempt in range(1, MAX_RETRIES + 1):\n",
        "    try:\n",
        '        print(f"Connecting to TRELLIS (attempt {attempt}/{MAX_RETRIES})...")\n',
        '        client = Client("trellis-community/TRELLIS")\n',
        '        client.predict(api_name="/start_session")\n',
        '        print("Session ready! Generating...")\n',
        "\n",
        "        result = client.predict(\n",
        '            image=handle_file("/kaggle/working/input_nobg.png"),\n',
        "            multiimages=[],\n",
        "            seed=0,\n",
        "            ss_guidance_strength=7.5,\n",
        "            ss_sampling_steps=12,\n",
        "            slat_guidance_strength=3.0,\n",
        "            slat_sampling_steps=12,\n",
        '            multiimage_algo="stochastic",\n',
        "            mesh_simplify=0.95,\n",
        "            texture_size=1024,\n",
        '            api_name="/generate_and_extract_glb"\n',
        "        )\n",
        '        print("Generation done!")\n',
        "        break\n",
        "\n",
        "    except Exception as e:\n",
        '        print(f"Attempt {attempt} failed: {e}")\n',
        "        if attempt < MAX_RETRIES:\n",
        "            time.sleep(30)\n",
        "        else:\n",
        '            raise RuntimeError(f"TRELLIS failed after {MAX_RETRIES} attempts: {e}")\n',
        "\n",
        "# 4. POST GLB back\n",
        "glb_src = result[1] or result[2]\n",
        'print(f"Sending GLB ({os.path.getsize(glb_src)/1024/1024:.1f} MB)...")\n',
        "\n",
        "with open(glb_src, 'rb') as f:\n",
        "    resp = requests.post(\n",
        '        f"{SPACE_URL}/receive_glb",\n',
        '        params={"job_id": JOB_ID},\n',
        '        files={"file": ("mesh.glb", f, "model/gltf-binary")},\n',
        "        timeout=120\n",
        "    )\n",
        "\n",
        "resp.raise_for_status()\n",
        'print(f"GLB delivered! Job {JOB_ID} complete.")\n',
    ]
    notebook = {
        "cells": [
            {
                "cell_type": "code",
                "execution_count": None,
                "id": "6117abdc",
                "metadata": {},
                "outputs": [],
                "source": source,
            }
        ],
        "metadata": {
            "kernelspec": {
                "display_name": "Python 3",
                "language": "python",
                "name": "python3",
            },
            "language_info": {"name": "python", "version": "3.12.0"},
        },
        "nbformat": 4,
        "nbformat_minor": 5,
    }
    return json.dumps(notebook, indent=1)


async def trigger_kaggle(job_id: str):
    """Push the generated worker notebook to Kaggle and poll for completion.

    Updates jobs[job_id]["status"] through: running -> done (set by
    /receive_glb) or an error:* state. Waits at most 15 minutes.
    """
    work_dir = None
    try:
        jobs[job_id]["status"] = "running"
        work_dir = tempfile.mkdtemp()

        # kernel-metadata.json — no env vars needed anymore; everything the
        # notebook needs is baked into its source.
        meta = {
            "id": KERNEL_REF,
            "title": "trellis-worker",
            "code_file": "worker.ipynb",
            "language": "python",
            "kernel_type": "notebook",
            "is_private": True,
            "enable_gpu": False,  # heavy lifting happens on the TRELLIS Space
            "enable_internet": True,
            "dataset_sources": [],
            "competition_sources": [],
            "kernel_sources": [],
        }
        with open(f"{work_dir}/kernel-metadata.json", "w") as f:
            json.dump(meta, f)

        # Generate notebook with JOB_ID + SPACE_URL baked in
        notebook_json = make_notebook(job_id, SPACE_URL)
        with open(f"{work_dir}/worker.ipynb", "w") as f:
            f.write(notebook_json)

        env = {**os.environ, "KAGGLE_USERNAME": KAGGLE_USER, "KAGGLE_KEY": KAGGLE_KEY}
        proc = await asyncio.create_subprocess_exec(
            "kaggle", "kernels", "push", "-p", work_dir,
            env=env,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await proc.communicate()
        print(f"[{job_id[:8]}] Kaggle push: {stdout.decode()} {stderr.decode()}")

        # Fail fast on a rejected push instead of waiting out the full timeout.
        if proc.returncode != 0:
            jobs[job_id]["status"] = "error:kaggle_push_failed"
            return

        for _ in range(60):  # wait up to 15 min
            await asyncio.sleep(15)
            if jobs[job_id]["status"] == "done":
                print(f"[{job_id[:8]}] Done!")
                return
        jobs[job_id]["status"] = "error:timeout"

    except Exception as e:
        print(f"[{job_id[:8]}] Error: {e}")
        jobs[job_id]["status"] = f"error:{e}"
    finally:
        # mkdtemp() dirs are never auto-removed; clean up to avoid leaking
        # one directory per job.
        if work_dir:
            shutil.rmtree(work_dir, ignore_errors=True)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)