JumaRubea commited on
Commit
57ed140
·
verified ·
1 Parent(s): a54fe83

create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fastapi
2
+ from fastapi import FastAPI, UploadFile, File
3
+ from fastapi.responses import JSONResponse
4
+ import cv2
5
+ import numpy as np
6
+ from detectron2.engine import DefaultPredictor
7
+ from detectron2.config import get_cfg
8
+ from detectron2.utils.visualizer import Visualizer
9
+ from detectron2.data import MetadataCatalog
10
+ import base64
11
+
12
+ app = FastAPI()
13
+
14
+ # Global variables
15
+ predictor = None
16
+ metadata = None
17
+
18
+ @app.on_event("startup")
19
+ async def load_model():
20
+ global predictor, metadata
21
+ try:
22
+ # Path to model and config
23
+ config_path = "mask_rcnn_config.yaml"
24
+ model_path = "model_final.pth"
25
+
26
+ # Initialize Detectron2 config
27
+ cfg = get_cfg()
28
+ cfg.merge_from_file(config_path)
29
+ cfg.MODEL.WEIGHTS = model_path
30
+ cfg.MODEL.DEVICE = "cpu"
31
+
32
+ # Set up class names in metadata
33
+ # Replace these with your actual class names
34
+ class_names = ["lesion", "light", "mucus"] # Add your class names here
35
+ MetadataCatalog.get("medical_train").thing_classes = class_names
36
+
37
+ predictor = DefaultPredictor(cfg)
38
+ metadata = MetadataCatalog.get("medical_train")
39
+ print("Model loaded successfully.")
40
+ except Exception as e:
41
+ print(f"Error loading model: {e}")
42
+
43
+ @app.post("/predict")
44
+ async def predict_image(file: UploadFile = File(...)):
45
+ try:
46
+ # Read the image from the file
47
+ img_bytes = await file.read()
48
+ npimg = np.frombuffer(img_bytes, np.uint8)
49
+ image = cv2.imdecode(npimg, cv2.IMREAD_COLOR)
50
+
51
+ # Make the prediction
52
+ outputs = predictor(image)
53
+ instances = outputs["instances"].to("cpu")
54
+
55
+ # Get all prediction information
56
+ pred_classes = instances.pred_classes.tolist()
57
+ scores = instances.scores.tolist()
58
+ masks = instances.pred_masks.numpy()
59
+ boxes = instances.pred_boxes.tensor.numpy()
60
+
61
+ # Convert class indices to class names
62
+ class_names = [metadata.thing_classes[idx] for idx in pred_classes]
63
+
64
+ # Visualize predictions
65
+ visualizer = Visualizer(image[:, :, ::-1], metadata, scale=1.2)
66
+ output_image = visualizer.draw_instance_predictions(instances).get_image()
67
+
68
+ # Convert the visualization image to base64
69
+ _, img_encoded = cv2.imencode('.jpg', output_image[:, :, ::-1])
70
+ img_base64 = base64.b64encode(img_encoded).decode('utf-8')
71
+
72
+ # Prepare the response
73
+ response_data = {
74
+ "visualization": img_base64,
75
+ "predictions": [
76
+ {
77
+ "class_name": class_name,
78
+ "class_id": class_id,
79
+ "score": float(score),
80
+ "bbox": box.tolist(),
81
+ }
82
+ for class_name, class_id, score, box
83
+ in zip(class_names, pred_classes, scores, boxes)
84
+ ],
85
+ }
86
+
87
+ return JSONResponse(content=response_data)
88
+
89
+ except Exception as e:
90
+ return JSONResponse(
91
+ status_code=500,
92
+ content={"error": str(e)}
93
+ )
94
+
95
+ if __name__ == "__main__":
96
+ import uvicorn
97
+ uvicorn.run(app, host="0.0.0.0", port=5000)