# Hugging Face Space: VAE CT slice generator (FastAPI + Gradio).
"""
VAE Hugging Face Space app.
Upload a mask image -> encode -> decode -> return one slice (slice 2 of 4).
"""
import base64
import io
import logging

import gradio as gr
import uvicorn
from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image

from inference import inference_to_png, OUTPUT_SLICE_INDEX

# Module-level logger for request tracing on /generate and /predict.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Gradio UI ---
def run_inference(mask: Image.Image) -> Image.Image:
    """Run the VAE on an uploaded mask and return the configured output slice.

    Raises gr.Error when no image was provided.
    """
    if mask is None:
        raise gr.Error("Please upload a mask image.")
    encoded = inference_to_png(mask, slice_index=OUTPUT_SLICE_INDEX)
    buffer = io.BytesIO(encoded)
    return Image.open(buffer).convert("L")
# Gradio front-end: single image in, single image out, wired to run_inference.
demo = gr.Interface(
    fn=run_inference,
    # type="pil" hands run_inference a PIL.Image (or None when nothing uploaded).
    inputs=gr.Image(label="Mask (grayscale)", type="pil"),
    outputs=gr.Image(label=f"Output slice {OUTPUT_SLICE_INDEX} of 4"),
    title="VAE CT Slice Generator",
    description=(
        "Upload a **mask** image (grayscale). The model encodes it, decodes to 4 slices (3D CT), "
        f"and returns **slice {OUTPUT_SLICE_INDEX}** as a 2D image for the web."
    ),
)
# --- FastAPI app (Gradio mounted at /) ---
app = FastAPI(title="VAE CT Slice API")
# CORS: allow your website (and others) to call /predict and /generate from the browser
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory — browsers reject credentialed responses whose
# Access-Control-Allow-Origin is the wildcard. If credentials are never sent,
# drop allow_credentials; otherwise list explicit origins. TODO confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)
@app.post("/generate")
async def generate(file: UploadFile):
    """Upload mask -> return single slice as data URI (same shape as diffusion /generate).

    Returns JSON {"image": "data:image/png;base64,..."}; any decode or
    inference failure becomes a 400 with the cause in `detail`.
    """
    # Defect fixed: the function was never registered as a route, although the
    # CORS comment above and this docstring both advertise POST /generate.
    try:
        raw = await file.read()
        logger.info("[request /generate] INPUT: filename=%s content_type=%s raw_bytes=%s",
                    file.filename, file.content_type, len(raw))
        # Force single-channel grayscale; inference_to_png is fed an "L" image.
        mask = Image.open(io.BytesIO(raw)).convert("L")
        logger.info("[request /generate] image opened: size=%s mode=%s", mask.size, mask.mode)
        png_bytes = inference_to_png(mask, slice_index=OUTPUT_SLICE_INDEX)
        b64 = base64.b64encode(png_bytes).decode("ascii")
        return JSONResponse(content={"image": f"data:image/png;base64,{b64}"})
    except Exception as e:
        # Chain the original cause so server logs keep the real traceback.
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.post("/predict")
async def predict(file: UploadFile):
    """Upload mask -> return single slice as base64 PNG in JSON.

    Returns JSON {"image": "<base64>", "slice_index": N}; any decode or
    inference failure becomes a 400 with the cause in `detail`.
    """
    # Defect fixed: the function was never registered as a route, although the
    # CORS comment above and this docstring both advertise POST /predict.
    try:
        raw = await file.read()
        logger.info("[request /predict] INPUT: filename=%s content_type=%s raw_bytes=%s",
                    file.filename, file.content_type, len(raw))
        # Force single-channel grayscale; inference_to_png is fed an "L" image.
        mask = Image.open(io.BytesIO(raw)).convert("L")
        logger.info("[request /predict] image opened: size=%s mode=%s", mask.size, mask.mode)
        png_bytes = inference_to_png(mask, slice_index=OUTPUT_SLICE_INDEX)
        b64 = base64.b64encode(png_bytes).decode("ascii")
        return JSONResponse(content={"image": b64, "slice_index": OUTPUT_SLICE_INDEX})
    except Exception as e:
        # Chain the original cause so server logs keep the real traceback.
        raise HTTPException(status_code=400, detail=str(e)) from e
def root():
    """Return a plain status payload describing the API endpoints.

    NOTE(review): this function is not bound to any route — a lost
    `@app.get(...)` decorator seems likely, but binding it to "/" would
    shadow the Gradio UI mounted there. Confirm the intended path.
    """
    return dict(
        status="ok",
        message="VAE CT slice API. Use /generate or /predict with a mask image.",
    )
# Mount the Gradio UI at the application root; FastAPI routes registered
# above keep working alongside it.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # 7860 is the port Hugging Face Spaces exposes by convention.
    uvicorn.run(app, host="0.0.0.0", port=7860)