File size: 3,150 Bytes
57ed140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5caa2af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import base64
import logging

import cv2
import fastapi
import numpy as np
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse

app = FastAPI()

# Global variables
predictor = None
metadata = None

@app.on_event("startup")
async def load_model():
    global predictor, metadata
    try:
        # Path to model and config
        config_path = "mask_rcnn_config.yaml"
        model_path = "model_final.pth"

        # Initialize Detectron2 config
        cfg = get_cfg()
        cfg.merge_from_file(config_path)
        cfg.MODEL.WEIGHTS = model_path
        cfg.MODEL.DEVICE = "cpu"
        
        # Set up class names in metadata
        # Replace these with your actual class names
        class_names = ["lesion", "light", "mucus"]  # Add your class names here
        MetadataCatalog.get("medical_train").thing_classes = class_names
        
        predictor = DefaultPredictor(cfg)
        metadata = MetadataCatalog.get("medical_train")
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")

@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    try:
        # Read the image from the file
        img_bytes = await file.read()
        npimg = np.frombuffer(img_bytes, np.uint8)
        image = cv2.imdecode(npimg, cv2.IMREAD_COLOR)

        # Make the prediction
        outputs = predictor(image)
        instances = outputs["instances"].to("cpu")

        # Get all prediction information
        pred_classes = instances.pred_classes.tolist()
        scores = instances.scores.tolist()
        masks = instances.pred_masks.numpy()
        boxes = instances.pred_boxes.tensor.numpy()

        # Convert class indices to class names
        class_names = [metadata.thing_classes[idx] for idx in pred_classes]

        # Visualize predictions
        visualizer = Visualizer(image[:, :, ::-1], metadata, scale=1.2)
        output_image = visualizer.draw_instance_predictions(instances).get_image()

        # Convert the visualization image to base64
        _, img_encoded = cv2.imencode('.jpg', output_image[:, :, ::-1])
        img_base64 = base64.b64encode(img_encoded).decode('utf-8')

        # Prepare the response
        response_data = {
            "visualization": img_base64,
            "predictions": [
                {
                    "class_name": class_name,
                    "class_id": class_id,
                    "score": float(score),
                    "bbox": box.tolist(),
                }
                for class_name, class_id, score, box 
                in zip(class_names, pred_classes, scores, boxes)
            ],
        }

        return JSONResponse(content=response_data)

    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)