# 3D_Model_AI / main.py
# Author: everydaycats
# Last commit: "Update main.py" (ec7e3da, verified)
# app.py
import os
import uuid
import shutil
import logging
import requests
import asyncio
import time
from typing import Optional, Dict, Any
from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Form
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from huggingface_hub import login
from app.utils import run_inference
# --- Configuration / env ---
# Logging is configured first so that problems in the remaining start-up steps
# (e.g. a failed Hugging Face login) are actually reported.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("stable-fast-3d-api")

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token)
    except Exception:
        # Non-fatal: some deployments cannot (or need not) authenticate, but the
        # failure is logged instead of being silently swallowed.
        logger.warning("Hugging Face login failed; continuing without authentication", exc_info=True)

# Scratch directory holding uploaded inputs and per-job output folders.
TMP_DIR = os.environ.get("TMP_DIR", "/app/tmp")
os.makedirs(TMP_DIR, exist_ok=True)

app = FastAPI(title="Stable Fast 3D API (Background Jobs)")

# In-memory job registry (lost on process restart).
# Structure:
# JOBS[request_id] = {
#     "status": "pending" | "running" | "done" | "error",
#     "input_path": "...",             # saved input image
#     "output_dir": "...",             # directory handed to run_inference
#     "glb_path": Optional[str],       # set once the job is done
#     "error": Optional[str],          # set when the job failed
#     "created_at": float,             # time.time() epoch seconds
#     "started_at": Optional[float],
#     "finished_at": Optional[float],
# }
JOBS: Dict[str, Dict[str, Any]] = {}
# Guards every read/write of JOBS from coroutines.
# NOTE(review): on Python < 3.10 asyncio.Lock binds to the event loop current at
# construction time; confirm the runtime is 3.10+ (or create the lock at startup).
JOBS_LOCK = asyncio.Lock()
# -------------------------
# Utility helpers
# -------------------------
def _save_upload_file(upload_file: UploadFile, dest_path: str) -> None:
with open(dest_path, "wb") as f:
shutil.copyfileobj(upload_file.file, f)
upload_file.file.close()
def _download_to_file(url: str, dest_path: str, timeout: int = 30) -> None:
    """Stream the resource at *url* into *dest_path*.

    Args:
        url: Image URL supplied by the client.
        dest_path: Destination file path; created or truncated.
        timeout: Passed to ``requests.get`` as its ``timeout`` (seconds).

    Raises:
        HTTPException: 400 if the remote server does not answer with status 200.
        requests.RequestException: on network errors or timeouts.
    """
    # NOTE(review): url is client-controlled and fetched server-side (SSRF risk),
    # and the body size is unbounded; consider a host allow-list and a size cap.
    # The context manager releases the connection even on the early error raise;
    # with stream=True it would otherwise stay checked out of the pool.
    with requests.get(url, stream=True, timeout=timeout) as resp:
        if resp.status_code != 200:
            raise HTTPException(status_code=400, detail=f"Failed to download image: status {resp.status_code}")
        with open(dest_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
def _find_glb_in_dir(output_dir: str) -> Optional[str]:
for root, _, files in os.walk(output_dir):
for fn in files:
if fn.lower().endswith(".glb"):
return os.path.join(root, fn)
return None
async def _set_job_field(job_id: str, key: str, value):
    """Set ``JOBS[job_id][key] = value`` while holding ``JOBS_LOCK``.

    Does nothing when *job_id* is not registered (e.g. it was deleted while its
    background worker was still running).
    """
    async with JOBS_LOCK:
        try:
            record = JOBS[job_id]
        except KeyError:
            return
        record[key] = value
async def _get_job(job_id: str):
    """Look up *job_id* under ``JOBS_LOCK``.

    Returns the live (mutable) job dict stored in ``JOBS``, or ``None`` if the
    id is unknown.
    """
    async with JOBS_LOCK:
        found = JOBS.get(job_id, None)
    return found
# -------------------------
# Background worker
# -------------------------
async def _background_run_inference(job_id: str):
    """Run ``run_inference`` for a registered job and record the outcome in ``JOBS``.

    The synchronous, heavy inference call and the filesystem lookup for the
    resulting GLB both run in worker threads (``asyncio.to_thread``) so the
    event loop stays responsive. Each state transition (``running``, then
    ``done`` or ``error``) is written together with its timestamp in a single
    locked update, so a concurrent ``/status`` poll never observes a
    half-updated job (e.g. ``done`` without ``finished_at``).

    Args:
        job_id: Key of a job previously registered in ``JOBS``.
    """

    async def _update(**fields: Any) -> None:
        # Apply several fields in one critical section; no-op if the job was deleted meanwhile.
        async with JOBS_LOCK:
            entry = JOBS.get(job_id)
            if entry is not None:
                entry.update(fields)

    job = await _get_job(job_id)
    if not job:
        logger.error("Job not found when starting background task: %s", job_id)
        return

    input_path = job["input_path"]
    output_dir = job["output_dir"]
    logger.info("[%s] Background job starting. input=%s output=%s", job_id, input_path, output_dir)
    await _update(status="running", started_at=time.time())

    try:
        # run_inference is synchronous / heavy — move to thread
        returned = await asyncio.to_thread(run_inference, input_path, output_dir)
        # Locating the GLB walks the filesystem; keep that off the event loop too.
        glb_path = await asyncio.to_thread(_resolve_glb_path, returned, output_dir)
        await _update(glb_path=glb_path, status="done", finished_at=time.time())
        logger.info("[%s] Background job finished successfully. glb=%s", job_id, glb_path)
    except Exception as e:
        # Top-level boundary of the background task: record the failure on the job.
        logger.exception("[%s] Background inference failed: %s", job_id, e)
        await _update(status="error", error=str(e), finished_at=time.time())


def _resolve_glb_path(candidate: Any, output_dir: str) -> str:
    """Return the GLB path for a finished job, or raise ``RuntimeError``.

    Uses *candidate* (whatever ``run_inference`` returned) when it is truthy and
    exists on disk; otherwise falls back to the first ``.glb`` found under
    *output_dir*. When neither works, the error message lists every file in
    *output_dir* to aid debugging.
    """
    if candidate and os.path.exists(candidate):
        return candidate
    found = _find_glb_in_dir(output_dir)
    if found and os.path.exists(found):
        return found
    listing = [os.path.join(root, fn) for root, _, files in os.walk(output_dir) for fn in files]
    raise RuntimeError(f"GLB not produced. output_dir listing: {listing}")
# -------------------------
# Embedded UI root (polling-based)
# -------------------------
@app.get("/", response_class=HTMLResponse)
async def root_ui():
    """Serve the single-page, polling-based UI for submitting and tracking jobs.

    The page posts to ``/generate-3d/``, polls ``/status/{id}`` every 5 s, and
    offers download (with a ``<model-viewer>`` preview) and deletion once the
    job reaches a terminal state. Polling stops in these cases, so intervals
    never pile up:
    - the job finishes or fails;
    - the job is gone (404);
    - the job is deleted;
    - a newer job or the Clear button replaces it.
    Responses for a job that is no longer the current one are ignored.
    """
    html = """
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>Stable Fast 3D API — Background Jobs</title>
<meta name="viewport" content="width=device-width,initial-scale=1" />
<script type="module" src="https://unpkg.com/@google/model-viewer/dist/model-viewer.min.js"></script>
<style>
body { font-family: Inter, system-ui, -apple-system, "Segoe UI", Roboto, "Helvetica Neue", Arial; padding: 24px; background:#f7f8fb; color:#111; }
.container { max-width:900px; margin:0 auto; background:white; padding:20px; border-radius:12px; box-shadow:0 6px 20px rgba(10,10,20,0.06); }
h1 { margin:0 0 12px 0; font-size:20px; }
label { display:block; margin-top:12px; font-weight:600; }
input[type="text"], input[type="url"] { width:100%; padding:8px 10px; margin-top:6px; border-radius:8px; border:1px solid #e2e8f0; }
.row { display:flex; gap:10px; align-items:center; margin-top:12px; }
button { padding:10px 14px; border-radius:8px; border: none; cursor:pointer; background:#111827; color:white; font-weight:600; }
.status { margin-top:12px; font-size:14px; color:#374151; }
model-viewer { width:100%; height:480px; background: #e6edf3; border-radius:8px; margin-top:16px; display:none; }
.download-link { margin-top:12px; display:block; }
</style>
</head>
<body>
<div class="container">
<h1>Stable Fast 3D API — Background Jobs</h1>
<p>Upload an image or paste an image URL to generate a 3D model (GLB). The job runs server-side and continues even if you close this page.</p>
<form id="generateForm">
<label for="fileInput">Upload image file</label>
<input id="fileInput" name="image" type="file" accept="image/*" />
<label for="urlInput">Or provide image URL</label>
<input id="urlInput" name="image_url" type="url" placeholder="https://example.com/image.png" />
<div class="row">
<button id="submitBtn" type="submit">Start Job</button>
<button id="clearBtn" type="button">Clear</button>
</div>
</form>
<div class="status" id="status">Status: idle</div>
<div id="jobArea" style="display:none;">
<p>Job ID: <code id="jobId"></code></p>
<p id="jobStatus">Waiting...</p>
<button id="downloadBtn" style="display:none;">Download GLB</button>
<button id="deleteBtn" style="display:none;">Delete Job & Files</button>
</div>
<model-viewer id="preview" camera-controls auto-rotate environment-image="neutral" style="display:none;"></model-viewer>
</div>
<script>
const form = document.getElementById('generateForm');
const fileInput = document.getElementById('fileInput');
const urlInput = document.getElementById('urlInput');
const status = document.getElementById('status');
const jobArea = document.getElementById('jobArea');
const jobIdEl = document.getElementById('jobId');
const jobStatusEl = document.getElementById('jobStatus');
const downloadBtn = document.getElementById('downloadBtn');
const deleteBtn = document.getElementById('deleteBtn');
const preview = document.getElementById('preview');
const submitBtn = document.getElementById('submitBtn');
const clearBtn = document.getElementById('clearBtn');
let pollInterval = null;
let currentJobId = null;
// Stop the status poller, if one is running (safe to call repeatedly).
function stopPolling() {
if (pollInterval !== null) {
clearInterval(pollInterval);
pollInterval = null;
}
}
// Forget the tracked job: stop polling and hide/unbind everything tied to it.
function resetJobUi() {
stopPolling();
currentJobId = null;
jobArea.style.display = 'none';
preview.style.display = 'none';
downloadBtn.style.display = 'none';
deleteBtn.style.display = 'none';
downloadBtn.onclick = null;
deleteBtn.onclick = null;
}
clearBtn.addEventListener('click', () => {
fileInput.value = '';
urlInput.value = '';
status.textContent = 'Status: idle';
resetJobUi();
});
form.addEventListener('submit', async (e) => {
e.preventDefault();
submitBtn.disabled = true;
status.textContent = 'Status: starting job...';
const hasFile = fileInput.files && fileInput.files.length > 0;
const hasUrl = urlInput.value && urlInput.value.trim().length > 0;
if (!hasFile && !hasUrl) {
status.textContent = 'Status: Please upload a file or provide an image URL.';
submitBtn.disabled = false;
return;
}
const formData = new FormData();
if (hasFile) formData.append('image', fileInput.files[0]);
else formData.append('image_url', urlInput.value.trim());
try {
const resp = await fetch('/generate-3d/', { method: 'POST', body: formData });
if (!resp.ok) {
const txt = await resp.text();
throw new Error('Server error: ' + resp.status + ' ' + txt);
}
const data = await resp.json();
const id = data.id;
// Drop any previous job (and its poller/buttons) before tracking the new one.
resetJobUi();
currentJobId = id;
jobIdEl.textContent = id;
jobStatusEl.textContent = 'Waiting...';
jobArea.style.display = 'block';
status.textContent = 'Status: job started: ' + id;
pollStatus(id);
pollInterval = setInterval(() => pollStatus(id), 5000);
} catch (err) {
console.error(err);
status.textContent = 'Error starting job: ' + (err.message || err);
} finally {
submitBtn.disabled = false;
}
});
async function pollStatus(id) {
if (id !== currentJobId) return;
jobStatusEl.textContent = 'Checking...';
try {
const resp = await fetch(`/status/${id}`);
// A newer job may have replaced this one while the request was in flight.
if (id !== currentJobId) return;
if (resp.status === 404) {
stopPolling();
jobStatusEl.textContent = 'Job not found';
return;
}
const data = await resp.json();
if (id !== currentJobId) return;
jobStatusEl.textContent = 'Status: ' + data.status + (data.error ? ' — ' + data.error : '');
if (data.status === 'done') {
stopPolling();
downloadBtn.style.display = 'inline-block';
deleteBtn.style.display = 'inline-block';
jobStatusEl.textContent += ' — ready';
// enable preview + download
downloadBtn.onclick = () => downloadGLB(id);
deleteBtn.onclick = () => deleteJob(id);
} else if (data.status === 'error') {
stopPolling();
// Bind delete to THIS job; otherwise the button would be inert or target a previous job.
deleteBtn.onclick = () => deleteJob(id);
deleteBtn.style.display = 'inline-block';
}
} catch (err) {
console.error('poll error', err);
if (id === currentJobId) jobStatusEl.textContent = 'Status: poll error';
}
}
async function downloadGLB(id) {
try {
const resp = await fetch(`/download/${id}`);
if (!resp.ok) {
const txt = await resp.text();
throw new Error('Download failed: ' + resp.status + ' ' + txt);
}
const blob = await resp.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'model_' + id + '.glb';
document.body.appendChild(a);
a.click();
a.remove();
// preview with model-viewer
preview.src = url;
preview.style.display = 'block';
setTimeout(() => URL.revokeObjectURL(url), 5 * 60 * 1000);
} catch (err) {
console.error(err);
alert('Download failed: ' + err.message);
}
}
async function deleteJob(id) {
if (!confirm('Delete job and all stored files? This is irreversible.')) return;
try {
const resp = await fetch(`/delete/${id}`, { method: 'DELETE' });
if (!resp.ok) {
const txt = await resp.text();
throw new Error('Delete failed: ' + resp.status + ' ' + txt);
}
alert('Deleted job ' + id);
if (id === currentJobId) resetJobUi();
} catch (err) {
console.error(err);
alert('Delete failed: ' + err.message);
}
}
</script>
</body>
</html>
"""
    return HTMLResponse(content=html, status_code=200)
# -------------------------
# API: Start job (non-blocking)
# -------------------------
# Strong references to in-flight background tasks. asyncio keeps only weak
# references to tasks, so a task from create_task() that is not stored anywhere
# can be garbage-collected before it finishes.
_BACKGROUND_TASKS: "set[asyncio.Task]" = set()


def _discard_job_files(input_path: str, output_dir: str) -> None:
    """Best-effort removal of a job's input file and output directory (used when job creation fails)."""
    try:
        os.remove(input_path)
    except FileNotFoundError:
        pass
    except OSError:
        logger.warning("Could not remove input file %s", input_path, exc_info=True)
    shutil.rmtree(output_dir, ignore_errors=True)


@app.post("/generate-3d/")
async def generate_3d_start(
    image: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
):
    """
    Start a background job to generate a 3D model.

    The uploaded *image* is used if present, otherwise *image_url*. The input is
    saved under TMP_DIR in a worker thread (so a slow download cannot stall the
    event loop). The job is then registered as "pending" and inference is
    scheduled without waiting for it.

    Returns JSON: { "id": "<job_id>", "status_url": "/status/<id>", "download_url": "/download/<id>" }

    Raises:
        HTTPException: 400 if no input was supplied or the URL download was
            rejected; 500 if saving the input failed otherwise. In both cases
            any files already created for the job are removed.
    """
    # Validate before touching the filesystem so a bad request leaves nothing behind.
    if image is None and not image_url:
        raise HTTPException(status_code=400, detail="Either image or image_url must be provided")

    request_id = str(uuid.uuid4())
    input_path = os.path.join(TMP_DIR, f"{request_id}.png")
    output_dir = os.path.join(TMP_DIR, f"{request_id}_output")
    os.makedirs(output_dir, exist_ok=True)

    # Save input; both helpers do blocking I/O, so run them off the event loop.
    try:
        if image is not None:
            await asyncio.to_thread(_save_upload_file, image, input_path)
        else:
            await asyncio.to_thread(_download_to_file, image_url, input_path, timeout=30)
    except HTTPException:
        _discard_job_files(input_path, output_dir)
        raise
    except Exception as e:
        _discard_job_files(input_path, output_dir)
        logger.exception("Failed to save input for job %s: %s", request_id, e)
        raise HTTPException(status_code=500, detail=f"Failed to save input: {e}") from e

    # Register job (pending)
    async with JOBS_LOCK:
        JOBS[request_id] = {
            "status": "pending",
            "input_path": input_path,
            "output_dir": output_dir,
            "glb_path": None,
            "error": None,
            "created_at": time.time(),
            "started_at": None,
            "finished_at": None,
        }

    # Kick off background task (does not block the request); keep a reference
    # until it completes so it cannot be garbage-collected mid-run.
    task = asyncio.create_task(_background_run_inference(request_id))
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    logger.info("Started background job %s", request_id)

    return JSONResponse({
        "id": request_id,
        "status_url": f"/status/{request_id}",
        "download_url": f"/download/{request_id}",
    })
# -------------------------
# API: Check status
# -------------------------
@app.get("/status/{job_id}")
async def job_status(job_id: str):
    """Return the public state of a job.

    Response keys:
    - ``id``.
    - ``status``: one of "pending" | "running" | "done" | "error".
    - ``glb_path``: kept for backward compatibility; despite the name it is a
      boolean, not a path.
    - ``has_glb``: the same boolean, named consistently with ``GET /jobs``.
    - ``error``.
    - ``created_at`` / ``started_at`` / ``finished_at``: ``time.time()``
      timestamps, ``None`` until reached.

    Raises:
        HTTPException: 404 if the job is unknown (never created, or deleted).
    """
    job = await _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    has_glb = bool(job.get("glb_path"))
    # return the public fields (never the server-side paths)
    return JSONResponse({
        "id": job_id,
        "status": job["status"],
        "glb_path": has_glb,
        "has_glb": has_glb,
        "error": job.get("error"),
        "created_at": job.get("created_at"),
        "started_at": job.get("started_at"),
        "finished_at": job.get("finished_at"),
    })
# -------------------------
# API: Download result (if ready)
# -------------------------
@app.get("/download/{job_id}")
async def download_result(job_id: str):
    """Stream the finished job's GLB file back to the client.

    The file is left on disk; callers must use ``DELETE /delete/{job_id}`` to
    remove it. Every failure is a 404, with one of three details: the job is
    unknown, the result is not ready (any status other than "done", or no GLB
    recorded), or the recorded GLB file no longer exists.
    """
    record = await _get_job(job_id)
    if not record:
        raise HTTPException(status_code=404, detail="Job not found")

    ready = record["status"] == "done" and bool(record.get("glb_path"))
    if not ready:
        raise HTTPException(status_code=404, detail="Result not ready")

    path = record["glb_path"]
    if os.path.exists(path):
        return FileResponse(path=path, media_type="model/gltf-binary", filename=os.path.basename(path))
    raise HTTPException(status_code=404, detail="GLB file missing on disk")
# -------------------------
# API: Delete job & files (manual)
# -------------------------
@app.delete("/delete/{job_id}")
async def delete_job(job_id: str):
    """Remove a job from the registry and delete its input file and output directory.

    The registry entry is popped atomically first. As a result, two concurrent
    deletes of the same id cannot both succeed, and ``/download`` stops serving
    the job before its files disappear. File-removal failures are reported in
    ``errors`` instead of failing the request.

    Returns:
        ``{"deleted": true}``, plus an ``errors`` list if some files could not be removed.

    Raises:
        HTTPException: 404 if the job is unknown.
    """
    async with JOBS_LOCK:
        job = JOBS.pop(job_id, None)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    job_state = job.get("status")
    if job_state in ("pending", "running"):
        # This endpoint does not stop the background worker; anything it writes
        # after this point is not cleaned up here.
        logger.warning("Deleting job %s while it is still %s; late worker output may be left on disk", job_id, job_state)

    # Remove files (the GLB lives inside output_dir, so it goes with the directory).
    input_path = job.get("input_path")
    output_dir = job.get("output_dir")
    errors = []
    try:
        if input_path and os.path.exists(input_path):
            os.remove(input_path)
    except Exception as e:
        errors.append(f"input removal error: {e}")
    try:
        if output_dir and os.path.exists(output_dir):
            # No ignore_errors: failures must reach the except so they are reported.
            shutil.rmtree(output_dir)
    except Exception as e:
        errors.append(f"output dir removal error: {e}")

    if errors:
        logger.warning("Delete job %s completed with errors: %s", job_id, errors)
        return JSONResponse({"deleted": True, "errors": errors})
    return JSONResponse({"deleted": True})
# -------------------------
# API: List jobs (optional)
# -------------------------
@app.get("/jobs")
async def list_jobs():
    """List every registered job as ``{job_id: summary}``.

    Each summary carries ``status``, the three lifecycle timestamps and a
    ``has_glb`` flag. Paths and error messages are not included.
    """
    summaries = {}
    async with JOBS_LOCK:
        for job_id, record in JOBS.items():
            summaries[job_id] = {
                "status": record["status"],
                "created_at": record["created_at"],
                "started_at": record["started_at"],
                "finished_at": record["finished_at"],
                "has_glb": bool(record.get("glb_path")),
            }
    return JSONResponse(summaries)