Tiffany Degbotse
add app
a79f121
raw
history blame contribute delete
877 Bytes
import base64
import io

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

from model import predict_galaxy
# FastAPI application instance; the route handlers in this file attach to it.
app = FastAPI()

# CORS is left fully open: every origin, HTTP method, and request header is
# accepted by the middleware.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction as JSON.

    The upload is decoded into an RGB image and handed to ``predict_galaxy``,
    which returns an overlay image and a text report. The report is parsed
    positionally and the overlay is returned as a base64-encoded PNG.

    Args:
        file: The uploaded image file (required).

    Returns:
        A dict with three keys: ``"class"`` (the text after ``": "`` on the
        report's first line), ``"probability"`` (the percentage on the
        report's second line divided by 100), and ``"heatmap"`` (the overlay
        as a base64-encoded PNG string).

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    payload = await file.read()
    try:
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError as exc:
        # Pillow reports unidentifiable or truncated image data as OSError
        # (UnidentifiedImageError is a subclass). That is a client error,
        # so answer 400 instead of letting it surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # NOTE(review): predict_galaxy is called synchronously inside this
    # coroutine; if inference is slow it will hold up the event loop.
    overlay, result_text = predict_galaxy(image)

    buf = io.BytesIO()
    overlay.save(buf, format="PNG")

    # The report is parsed by position: line 1 holds the class and line 2
    # holds the percentage, each following a ": " separator.
    # Presumably shaped like "<label>: <class>\n<label>: <NN.N>%" -- confirm
    # against predict_galaxy.
    report_lines = result_text.split("\n")
    predicted_class = report_lines[0].split(": ")[1]
    percent_text = report_lines[1].split(": ")[1].replace("%", "")
    probability = float(percent_text) / 100

    return {
        "class": predicted_class,
        "probability": probability,
        "heatmap": base64.b64encode(buf.getvalue()).decode(),
    }