# Last commit by JumaRubea: 5caa2af "Update app.py" (verified)
import fastapi
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import cv2
import numpy as np
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
import base64
# ASGI application instance; the startup hook and routes in this file are
# registered on it via decorators.
app = FastAPI()

# Module-level model state. Both start as None and are assigned by the
# startup handler (load_model) through a `global` statement.
predictor, metadata = None, None
@app.on_event("startup")
async def load_model():
    """Build the Detectron2 predictor and class metadata when the app starts.

    Reads the config from ``mask_rcnn_config.yaml`` and the weights from
    ``model_final.pth`` (paths relative to the working directory) and forces
    CPU inference. On success, the module-level ``predictor`` and ``metadata``
    globals are assigned together.

    Any exception is printed (message plus traceback) and swallowed, so the
    server still starts. In that case both globals are left untouched.
    """
    global predictor, metadata
    try:
        config_path = "mask_rcnn_config.yaml"
        model_path = "model_final.pth"

        cfg = get_cfg()
        cfg.merge_from_file(config_path)
        cfg.MODEL.WEIGHTS = model_path
        cfg.MODEL.DEVICE = "cpu"

        # Names are looked up by predicted class index, so this list's order
        # presumably has to match the training dataset's class order — confirm.
        # Replace these with your actual class names.
        class_names = ["lesion", "light", "mucus"]
        dataset_metadata = MetadataCatalog.get("medical_train")
        dataset_metadata.thing_classes = class_names

        loaded_predictor = DefaultPredictor(cfg)
    except Exception as e:
        import traceback

        # Best-effort: report the failure (with traceback) but let the app start.
        print(f"Error loading model: {e}")
        traceback.print_exc()
    else:
        # Publish both globals only after every step succeeded, so a request
        # can never observe a predictor without its metadata.
        predictor = loaded_predictor
        metadata = dataset_metadata
        print("Model loaded successfully.")
def _decode_image(img_bytes):
    """Decode uploaded bytes into a 3-channel BGR image array.

    Returns ``None`` when the payload is empty or OpenCV cannot decode it.
    """
    if not img_bytes:
        # An empty buffer may make cv2.imdecode raise rather than return None.
        return None
    npimg = np.frombuffer(img_bytes, np.uint8)
    # imdecode signals undecodable data by returning None, not by raising.
    return cv2.imdecode(npimg, cv2.IMREAD_COLOR)


@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Run instance segmentation on an uploaded image.

    Response JSON:
      - ``visualization``: base64-encoded JPEG of the input with predictions
        drawn, rendered at 1.2x the input size.
      - ``predictions``: one entry per detected instance with ``class_name``,
        ``class_id``, ``score`` and ``bbox``. The ``bbox`` holds the raw
        ``pred_boxes`` coordinates; it is not affected by the visualization's
        1.2x scale.

    Errors are returned as ``{"error": <message>}`` with these statuses:
      - 503 if the model is not loaded;
      - 400 if the upload cannot be decoded as an image;
      - 500 for any other failure.
    """
    if predictor is None or metadata is None:
        # The startup loader swallows its errors, so the server can be up
        # without a usable model.
        return JSONResponse(
            status_code=503,
            content={"error": "Model is not loaded."},
        )
    try:
        img_bytes = await file.read()
        image = _decode_image(img_bytes)
        if image is None:
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded file could not be decoded as an image."},
            )

        # `image` is BGR as decoded by OpenCV and is passed to the predictor
        # unchanged.
        # NOTE: inference runs synchronously inside this async handler and
        # blocks the event loop while it runs.
        outputs = predictor(image)
        instances = outputs["instances"].to("cpu")

        pred_classes = instances.pred_classes.tolist()
        scores = instances.scores.tolist()
        boxes = instances.pred_boxes.tensor.numpy()
        class_names = [metadata.thing_classes[idx] for idx in pred_classes]

        # Visualizer is given the channel-flipped (RGB) image. The rendered
        # result is flipped back to BGR before OpenCV encodes it.
        visualizer = Visualizer(image[:, :, ::-1], metadata, scale=1.2)
        output_image = visualizer.draw_instance_predictions(instances).get_image()
        ok, img_encoded = cv2.imencode(".jpg", output_image[:, :, ::-1])
        if not ok:
            raise RuntimeError("Failed to JPEG-encode the visualization.")
        img_base64 = base64.b64encode(img_encoded).decode("utf-8")

        response_data = {
            "visualization": img_base64,
            "predictions": [
                {
                    "class_name": class_name,
                    "class_id": class_id,
                    "score": float(score),
                    "bbox": box.tolist(),
                }
                for class_name, class_id, score, box
                in zip(class_names, pred_classes, scores, boxes)
            ],
        }
        return JSONResponse(content=response_data)
    except Exception as e:
        import traceback

        # Keep the full traceback in the server log; the client gets only the message.
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is executed directly.
    import uvicorn

    # Listen on all network interfaces at port 7860.
    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)