Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from PIL import Image | |
| from transformers import MobileNetV2ForSemanticSegmentation, AutoImageProcessor | |
| import torch | |
| from io import BytesIO | |
| import base64 | |
| import numpy as np | |
# ASGI application object for this service.
app = FastAPI()

# CORS is left fully open: every origin, HTTP method and request header is
# accepted by the middleware.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Both the image processor and the segmentation model are loaded once, at
# import time, from the same "seg_model" checkpoint identifier, so every
# request reuses the same objects.
_MODEL_DIR = "seg_model"
processor = AutoImageProcessor.from_pretrained(_MODEL_DIR)
model = MobileNetV2ForSemanticSegmentation.from_pretrained(_MODEL_DIR)
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# handler in this view -- confirm that it is actually registered on ``app``.
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return the class mask as a PNG data URI.

    The upload is decoded to RGB, run through the module-level ``processor``
    and ``model`` with gradients disabled, and reduced to one label per pixel
    by taking the argmax over the label axis of the logits.

    Args:
        file: The uploaded image file.

    Returns:
        A dict ``{"success": True, "mask": "data:image/png;base64,..."}``.
        The PNG is a single-channel 8-bit image whose pixel values are the
        predicted class indices. Its size is the spatial size of the model's
        logits, which may differ from the uploaded image's size -- TODO
        confirm whether clients expect the original resolution.

    Raises:
        HTTPException: With status 400 when the upload is empty or cannot be
            decoded as an image.
    """
    # Imported locally so this handler does not depend on an edit to the
    # file-level import block.
    from fastapi import HTTPException

    contents = await file.read()

    # The upload is untrusted: empty, non-image or truncated data makes Pillow
    # raise OSError (UnidentifiedImageError is a subclass). Report that as a
    # client error instead of letting it surface as an opaque 500.
    try:
        img = Image.open(BytesIO(contents)).convert("RGB")
    except OSError as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits  # (batch, num_labels, H, W)
    # Only one image was submitted, so take the first (and only) batch item.
    # NOTE(review): the uint8 cast assumes fewer than 256 labels -- confirm.
    mask = torch.argmax(logits, dim=1)[0].numpy().astype(np.uint8)

    # Encode the label map as a grayscale PNG and wrap it in a base64 data URI.
    mask_img = Image.fromarray(mask)
    buf = BytesIO()
    mask_img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return {"success": True, "mask": "data:image/png;base64," + b64}