"""FastAPI service that runs a Detectron2 model on uploaded images.

``POST /predict`` accepts an image upload and returns the detected instances
(class name, class id, score, bounding box) together with a base64-encoded
JPEG of the predictions drawn on the input image.
"""

import base64
import logging
import threading

import cv2
import fastapi
import numpy as np
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

logger = logging.getLogger(__name__)

app = FastAPI()

# Global variables, populated by load_model() at startup. They stay None if
# loading fails, in which case /predict answers 503.
predictor = None
metadata = None

# Inference runs in a worker thread so the event loop stays responsive; this
# lock keeps the original one-inference-at-a-time behaviour instead of letting
# concurrent requests run the model in parallel and oversubscribe the CPU.
_predict_lock = threading.Lock()


@app.on_event("startup")
async def load_model():
    """Build the Detectron2 predictor and class metadata once at startup.

    Any failure is logged with its traceback and leaves ``predictor`` as
    ``None``; the ``/predict`` endpoint then reports the model as unavailable.
    """
    global predictor, metadata
    try:
        # Path to model and config
        config_path = "mask_rcnn_config.yaml"
        model_path = "model_final.pth"

        # Initialize Detectron2 config
        cfg = get_cfg()
        cfg.merge_from_file(config_path)
        cfg.MODEL.WEIGHTS = model_path
        cfg.MODEL.DEVICE = "cpu"

        # Set up class names in metadata. Replace these with your actual class
        # names; predicted class indices are looked up in this list, so the
        # order must match the indices the model was trained with.
        class_names = ["lesion", "light", "mucus"]
        MetadataCatalog.get("medical_train").thing_classes = class_names

        predictor = DefaultPredictor(cfg)
        metadata = MetadataCatalog.get("medical_train")
        logger.info("Model loaded successfully.")
    except Exception:
        logger.exception("Error loading model")


def _decode_image(img_bytes):
    """Decode raw upload bytes into an OpenCV (BGR) image.

    Returns ``None`` when the bytes are empty or not a decodable image
    (``cv2.imdecode`` signals failure by returning ``None``; an empty buffer
    would make it raise instead, so that case is checked first).
    """
    if not img_bytes:
        return None
    npimg = np.frombuffer(img_bytes, np.uint8)
    return cv2.imdecode(npimg, cv2.IMREAD_COLOR)


def _run_inference(image):
    """Run the predictor on *image* and build the JSON-serializable response.

    Blocking and CPU-bound; called via ``run_in_threadpool``.

    Raises:
        RuntimeError: if the visualization cannot be JPEG-encoded.
    """
    with _predict_lock:
        outputs = predictor(image)
    instances = outputs["instances"].to("cpu")

    # Get all prediction information
    pred_classes = instances.pred_classes.tolist()
    scores = instances.scores.tolist()
    boxes = instances.pred_boxes.tensor.numpy()

    # Convert class indices to class names
    class_names = [metadata.thing_classes[idx] for idx in pred_classes]

    # Visualizer expects RGB, while the decoded image is in OpenCV's BGR order.
    visualizer = Visualizer(image[:, :, ::-1], metadata, scale=1.2)
    output_image = visualizer.draw_instance_predictions(instances).get_image()

    # get_image() returns RGB; convert back to BGR (as a contiguous array)
    # before handing it to OpenCV's encoder.
    ok, img_encoded = cv2.imencode(
        ".jpg", cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
    )
    if not ok:
        raise RuntimeError("Failed to encode visualization as JPEG")
    img_base64 = base64.b64encode(img_encoded.tobytes()).decode("utf-8")

    return {
        "visualization": img_base64,
        "predictions": [
            {
                "class_name": class_name,
                "class_id": class_id,
                "score": float(score),
                "bbox": box.tolist(),
            }
            for class_name, class_id, score, box in zip(
                class_names, pred_classes, scores, boxes
            )
        ],
    }


@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Run instance detection on an uploaded image.

    Responses:
        200: ``{"visualization": <base64 JPEG>, "predictions": [...]}``
        400: the upload is empty or not a decodable image.
        503: the model failed to load at startup.
        500: any other error during inference (message in ``"error"``).
    """
    if predictor is None or metadata is None:
        return JSONResponse(
            status_code=503, content={"error": "Model is not loaded"}
        )
    try:
        img_bytes = await file.read()
        image = await run_in_threadpool(_decode_image, img_bytes)
        if image is None:
            return JSONResponse(
                status_code=400,
                content={"error": "Could not decode uploaded file as an image"},
            )

        # Inference, visualization and encoding are CPU-bound; keep them off
        # the event loop so other requests are not stalled.
        response_data = await run_in_threadpool(_run_inference, image)
        return JSONResponse(content=response_data)
    except Exception as e:
        logger.exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": str(e)})


if __name__ == "__main__":
    import uvicorn

    logging.basicConfig(level=logging.INFO)
    uvicorn.run(app, host="0.0.0.0", port=7860)