# InteriorFusion / api/main.py
# Uploaded by stevee00 ("Upload api/main.py", commit 04de76b, verified)
"""FastAPI backend for InteriorFusion inference service."""
import os
import tempfile
from pathlib import Path
from typing import List, Optional
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
import io
import base64
from interiorfusion.pipelines import InteriorFusionPipeline, InteriorFusionOutput
# Application instance that the route decorators below attach to.
app = FastAPI(
    title="InteriorFusion API",
    description="Single image to 3D interior scene generation",
    version="0.1.0",
)

# CORS: the API accepts requests from any origin. Credentialed requests
# (cookies / Authorization headers) stay disabled because a wildcard origin
# must not be combined with allow_credentials=True; to support credentials,
# list the trusted origins explicitly instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Process-wide pipeline instance. It starts as None and is filled in by the
# first call to get_pipeline(), so importing this module does not build it.
_pipeline: Optional[InteriorFusionPipeline] = None


def get_pipeline() -> InteriorFusionPipeline:
    """Return the shared pipeline, constructing it on the first call.

    The pipeline is created once with ``model_size="L"`` and cached in the
    module-level ``_pipeline``; subsequent calls return the same object.

    CUDA is used when ``torch.cuda.is_available()`` reports a device,
    otherwise the pipeline runs on the CPU. Half precision is requested only
    on CUDA; the CPU fallback uses float32.

    Returns:
        The cached ``InteriorFusionPipeline`` instance.
    """
    global _pipeline
    if _pipeline is None:
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # float16 is a GPU-oriented dtype: on CPU many operators are slow or
        # not implemented for it, so fall back to full precision there.
        dtype = torch.float16 if use_cuda else torch.float32
        _pipeline = InteriorFusionPipeline(
            model_size="L",
            device=device,
            dtype=dtype,
        )
    return _pipeline
@app.post("/generate")
async def generate_3d_scene(
    image: UploadFile = File(...),
    room_type: Optional[str] = Form(None),
    style: Optional[str] = Form(None),
    formats: str = Form("glb,ply"),
    model_size: str = Form("L"),
):
    """
    Generate a 3D interior scene from a single image.

    Args:
        image: Uploaded picture; it is decoded and converted to RGB.
        room_type: Optional hint forwarded to the pipeline as
            ``room_type_hint``.
        style: Optional hint forwarded to the pipeline as ``style_hint``.
        formats: Comma-separated file extensions to report (e.g. "glb,ply").
            Blank entries and entries that are not purely alphanumeric are
            ignored.
        model_size: Accepted for API compatibility.
            NOTE(review): currently unused -- the shared pipeline is always
            the one built by get_pipeline().

    Returns:
        A JSON response with ``success``, ``room_type``, ``style``,
        ``processing_time``, ``num_objects`` and ``files``. ``files`` maps
        each requested format whose ``scene.<fmt>`` export exists to that
        file's path inside a temporary directory on the server.
        If the upload cannot be decoded as an image, a 400 response with
        ``success: False`` and a ``message`` is returned instead.
    """
    # Read image; reject undecodable uploads with a client error instead of
    # letting the decoder exception surface as an HTTP 500.
    contents = await image.read()
    try:
        img = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError):
        return JSONResponse(
            {
                "success": False,
                "message": "Uploaded file is not a readable image",
            },
            status_code=400,
        )

    # Parse formats. Each token ends up inside a file name, so keep only
    # plain alphanumeric extensions (this also drops empty tokens).
    output_formats = [
        fmt
        for fmt in (token.strip() for token in formats.split(","))
        if fmt.isalnum()
    ]

    # Run pipeline.
    # NOTE(review): this call is synchronous and runs inside the coroutine,
    # so it occupies the event loop for the duration of the inference.
    pipeline = get_pipeline()
    output = pipeline(
        image=img,
        room_type_hint=room_type,
        style_hint=style,
    )

    # Export into a fresh temporary directory.
    # NOTE(review): the directory is not removed by this handler.
    output_dir = tempfile.mkdtemp()
    output.export_all(output_dir)

    # Collect the paths of the requested formats that were actually written.
    files = {}
    for fmt in output_formats:
        path = Path(output_dir) / f"scene.{fmt}"
        if path.exists():
            files[fmt] = str(path)

    return JSONResponse({
        "success": True,
        "room_type": output.room_type,
        "style": output.style,
        "processing_time": output.processing_time,
        "num_objects": len(output.object_meshes),
        "files": files,
    })
@app.post("/edit")
async def edit_scene(
    scene_glb: UploadFile = File(...),
    edit_action: str = Form(...),  # "move", "replace", "remove", "add"
    object_id: Optional[int] = Form(None),
    new_image: Optional[UploadFile] = File(None),
    position: Optional[str] = Form(None),  # JSON array [x, y, z]
):
    """
    Edit an existing scene (not implemented yet).

    Actions:
    - move: Move an existing object
    - replace: Replace an object with a new one
    - remove: Remove an object
    - add: Add a new object

    The request is parsed into an edit description, but no edit is applied:
    the endpoint always answers with ``success: False`` and a
    "coming soon" message. The inference pipeline is deliberately not
    loaded here, since nothing would use it.

    Args:
        scene_glb: Uploaded scene file (currently not read).
        edit_action: Name of the requested action.
        object_id: Optional identifier of the object to edit.
        new_image: Optional image upload; decoded and converted to RGB.
        position: Optional JSON-encoded position, e.g. "[x, y, z]".

    Returns:
        A JSON response with ``success: False`` and a ``message``. Malformed
        ``position`` JSON or an undecodable ``new_image`` yields a 400
        response of the same shape.
    """
    import json

    # Build edit dict
    edit = {"action": edit_action}
    if object_id is not None:
        edit["object_id"] = object_id

    # Parse position; json.JSONDecodeError is a ValueError subclass.
    if position:
        try:
            pos = json.loads(position)
        except ValueError:
            return JSONResponse(
                {"success": False, "message": "position must be valid JSON"},
                status_code=400,
            )
        if pos:
            edit["position"] = pos

    if new_image:
        contents = await new_image.read()
        try:
            edit["new_image"] = Image.open(io.BytesIO(contents)).convert("RGB")
        except (OSError, ValueError):
            return JSONResponse(
                {
                    "success": False,
                    "message": "new_image is not a readable image",
                },
                status_code=400,
            )

    # `edit` is assembled for the upcoming implementation but not applied yet.
    # For simplicity, return not-implemented
    return JSONResponse({
        "success": False,
        "message": "Scene editing API coming soon",
    })
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A dict with ``status`` set to "ok" and ``version`` taken from the
        application object, so the reported version cannot drift from the
        one declared when the app was created.
    """
    return {"status": "ok", "version": app.version}
@app.get("/download/{filename}")
async def download_file(filename: str):
    """Placeholder for downloading a generated file.

    No file is served yet: the response is a JSON message that echoes the
    requested ``filename``.
    """
    # A real deployment should back this with proper file storage (S3, etc.);
    # until then only a message is returned.
    payload = {"message": f"Download {filename} from storage"}
    return JSONResponse(payload)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    bind_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **bind_options)