tan200224's picture
Update app.py
895abe6 verified
"""
VAE Hugging Face Space app.
Upload a mask image -> encode -> decode -> return one slice (slice 2 of 4).
"""
import base64
import io
import logging

from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

import gradio as gr
from PIL import Image
import uvicorn
from inference import inference_to_png, OUTPUT_SLICE_INDEX
# --- Gradio UI ---
def run_inference(mask: Image.Image) -> Image.Image:
    """Gradio handler: render the configured output slice for an uploaded mask.

    Args:
        mask: Image from the ``gr.Image(type="pil")`` input, or None when the
            user submits without uploading anything.

    Returns:
        The output slice (slice ``OUTPUT_SLICE_INDEX``) as a grayscale ("L")
        PIL image, decoded from the PNG bytes produced by ``inference_to_png``.

    Raises:
        gr.Error: if no image was uploaded (displayed to the user in the UI).
    """
    if mask is None:
        raise gr.Error("Please upload a mask image.")
    # The REST endpoints convert every upload to single-channel "L" before
    # inference; do the same here so the UI and the API feed the model the
    # same kind of image (the gr.Image input sets no image_mode, so it may
    # deliver RGB/RGBA). Converting an image that is already "L" is harmless.
    png_bytes = inference_to_png(mask.convert("L"), slice_index=OUTPUT_SLICE_INDEX)
    return Image.open(io.BytesIO(png_bytes)).convert("L")
# Single-input / single-output Gradio interface around run_inference:
# a PIL mask goes in, the rendered output slice comes out.
demo = gr.Interface(
    fn=run_inference,
    inputs=gr.Image(type="pil", label="Mask (grayscale)"),
    outputs=gr.Image(label=f"Output slice {OUTPUT_SLICE_INDEX} of 4"),
    title="VAE CT Slice Generator",
    description=(
        "Upload a **mask** image (grayscale). "
        "The model encodes it, decodes to 4 slices (3D CT), "
        f"and returns **slice {OUTPUT_SLICE_INDEX}** "
        "as a 2D image for the web."
    ),
)
# --- FastAPI app (Gradio mounted at /) ---
app = FastAPI(title="VAE CT Slice API")

# CORS: browsers on any origin (your own website included) may call
# /predict and /generate directly.
# NOTE(review): combining a wildcard origin with allow_credentials=True is
# questionable. The CORS spec forbids "*" on credentialed responses, and
# Starlette's CORSMiddleware docs indicate origins/methods/headers should be
# listed explicitly when credentials are allowed (verify for the installed
# version). Nothing in this file reads cookies or auth headers, so credentials
# may be unnecessary. Confirm no client sends credentialed requests before
# switching this to False.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_credentials=True,
)
def _decode_mask_and_infer(raw: bytes, endpoint: str) -> bytes:
    """Decode uploaded bytes as a grayscale mask and return the output slice as PNG bytes.

    Kept synchronous on purpose: ``_mask_upload_to_b64`` runs it in a worker
    thread so PIL decoding and model inference do not block the event loop.

    Args:
        raw: The raw uploaded file contents.
        endpoint: Route path, used only to tag log lines.
    """
    mask = Image.open(io.BytesIO(raw)).convert("L")
    logger.info("[request %s] image opened: size=%s mode=%s", endpoint, mask.size, mask.mode)
    return inference_to_png(mask, slice_index=OUTPUT_SLICE_INDEX)


async def _mask_upload_to_b64(file: UploadFile, endpoint: str) -> str:
    """Shared pipeline for /generate and /predict.

    Reads the upload, decodes it as a grayscale ("L") mask, renders slice
    ``OUTPUT_SLICE_INDEX`` via ``inference_to_png`` and returns the PNG as a
    base64 ASCII string.

    Args:
        file: The uploaded mask image.
        endpoint: Route path, used only to tag log lines.

    Raises:
        HTTPException: status 400 for any failure (empty upload, undecodable
            image, inference error), with the error message as ``detail``.
    """
    try:
        raw = await file.read()
        logger.info(
            "[request %s] INPUT: filename=%s content_type=%s raw_bytes=%s",
            endpoint, file.filename, file.content_type, len(raw),
        )
        if not raw:
            # Fail fast with a clear message instead of PIL's opaque
            # "cannot identify image file" error.
            raise ValueError("Uploaded file is empty.")
        # Decoding and inference are synchronous calls. Running them in a worker
        # thread keeps a long inference from stalling every other request
        # (including the mounted Gradio UI) on the event loop.
        png_bytes = await run_in_threadpool(_decode_mask_and_infer, raw, endpoint)
    except Exception as exc:
        # Log the full traceback server-side so real faults are diagnosable.
        # The client still receives a 400 carrying the error message.
        logger.exception("[request %s] failed", endpoint)
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return base64.b64encode(png_bytes).decode("ascii")


@app.post("/generate")
async def generate(file: UploadFile):
    """Upload mask -> return single slice as data URI (same shape as diffusion /generate)."""
    b64 = await _mask_upload_to_b64(file, "/generate")
    return JSONResponse(content={"image": f"data:image/png;base64,{b64}"})


@app.post("/predict")
async def predict(file: UploadFile):
    """Upload mask -> return single slice as base64 PNG in JSON, plus the slice index."""
    b64 = await _mask_upload_to_b64(file, "/predict")
    return JSONResponse(content={"image": b64, "slice_index": OUTPUT_SLICE_INDEX})
# NOTE(review): this route is registered before Gradio is mounted at "/" (see
# the mount_gradio_app call below). Starlette matches routes in registration
# order, so GET / appears to return this JSON instead of the Gradio UI's index
# page. If the UI should be reachable at the root, move one of them (e.g.
# Gradio to "/ui" or this status to "/health"). Confirm before changing, since
# clients may rely on either.
@app.get("/")
def root():
    """Status endpoint that points API users at /generate and /predict."""
    payload = {"status": "ok"}
    payload["message"] = "VAE CT slice API. Use /generate or /predict with a mask image."
    return payload
# Serve the Gradio UI from the same ASGI application as the REST endpoints.
# The result is rebound to `app`, so anything that imports `app` (or the
# runner below) gets the combined application.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # 0.0.0.0 binds every network interface; 7860 is presumably the port the
    # Space is configured to expose (confirm against the Space settings).
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)