# Source: Hugging Face Space (status when captured: Sleeping).
# File size: 3,150 bytes; revisions 57ed140 and 5caa2af.
import base64
import logging

import cv2
import fastapi
import numpy as np
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
# ASGI application object; routes and the startup hook register on it below.
app = FastAPI()

# Shared model state. Both names start as None and are filled in by the
# startup hook; they remain None if model loading fails.
predictor: "DefaultPredictor | None" = None
metadata = None  # Detectron2 metadata entry holding the class names
# NOTE(review): FastAPI deprecates @app.on_event in favour of lifespan
# handlers; kept here for compatibility with the rest of the app.
@app.on_event("startup")
async def load_model():
    """Load the Detectron2 predictor and class metadata when the app starts.

    Reads the config from ``mask_rcnn_config.yaml`` and the weights from
    ``model_final.pth``, forces CPU inference, and registers ``class_names``
    as ``thing_classes`` on the ``"medical_train"`` metadata entry.

    Populates the module-level ``predictor`` and ``metadata`` globals. Loading
    is best-effort: any exception is logged with its traceback rather than
    re-raised. Both globals are assigned together only after everything has
    loaded, so on failure they stay ``None``. The server still starts, and
    ``/predict`` can detect the missing model.
    """
    global predictor, metadata
    logger = logging.getLogger(__name__)
    try:
        # Paths to the model config and weights (relative to the working dir).
        config_path = "mask_rcnn_config.yaml"
        model_path = "model_final.pth"

        # Initialize the Detectron2 config.
        cfg = get_cfg()
        cfg.merge_from_file(config_path)
        cfg.MODEL.WEIGHTS = model_path
        cfg.MODEL.DEVICE = "cpu"

        # Replace these with your actual class names; list position must match
        # the class index the model was trained with.
        class_names = ["lesion", "light", "mucus"]
        num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES
        if num_classes != len(class_names):
            # Predicted class ids index into thing_classes, so a model with
            # more classes than names fails later with an IndexError.
            logger.warning(
                "Config declares %d classes but %d class names are defined",
                num_classes,
                len(class_names),
            )
        train_metadata = MetadataCatalog.get("medical_train")
        train_metadata.thing_classes = class_names

        loaded_predictor = DefaultPredictor(cfg)

        # Publish both globals only once every step above has succeeded.
        predictor, metadata = loaded_predictor, train_metadata
        print("Model loaded successfully.")
    except Exception:
        # Keep the server up; logger.exception records the full traceback.
        logger.exception("Error loading model")
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Run the loaded Detectron2 model on an uploaded image.

    Args:
        file: Uploaded image in any format ``cv2.imdecode`` can decode.

    Returns:
        JSONResponse with two keys. ``visualization`` is a base64-encoded
        JPEG of the predictions drawn on the image at 1.2x scale.
        ``predictions`` holds one dict per detected instance, with keys
        ``class_name``, ``class_id``, ``score`` and ``bbox``. Failures are
        returned as ``{"error": ...}``:

        - 503 when the model is not loaded;
        - 400 when the upload is empty or not a decodable image;
        - 500 for any other exception.
    """
    logger = logging.getLogger(__name__)

    # load_model() swallows its own errors, so the model may be missing.
    # Without this guard, calling None would surface as an opaque 500.
    if predictor is None or metadata is None:
        return JSONResponse(
            status_code=503,
            content={"error": "Model is not loaded; check the server logs."},
        )

    try:
        # Read and decode the uploaded image (OpenCV decodes to BGR).
        img_bytes = await file.read()
        image = None
        if img_bytes:
            # imdecode returns None for data it cannot decode; an empty
            # buffer may raise instead, hence the emptiness check first.
            npimg = np.frombuffer(img_bytes, np.uint8)
            image = cv2.imdecode(npimg, cv2.IMREAD_COLOR)
        if image is None:
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded file is empty or is not a decodable image."},
            )

        # NOTE(review): CPU inference runs directly on the event loop and
        # blocks other requests meanwhile; consider a worker thread once the
        # predictor's thread-safety has been confirmed.
        outputs = predictor(image)
        instances = outputs["instances"].to("cpu")

        # Per-instance prediction fields.
        pred_classes = instances.pred_classes.tolist()
        scores = instances.scores.tolist()
        # presumably (x1, y1, x2, y2) in input-image pixels (Detectron2 Boxes);
        # note these are NOT rescaled to the 1.2x visualization below.
        boxes = instances.pred_boxes.tensor.numpy()

        # Convert class indices to class names.
        class_names = [metadata.thing_classes[idx] for idx in pred_classes]

        # Visualizer takes RGB, so flip the BGR image in, and flip back to
        # BGR for cv2.imencode.
        visualizer = Visualizer(image[:, :, ::-1], metadata, scale=1.2)
        output_image = visualizer.draw_instance_predictions(instances).get_image()
        ok, img_encoded = cv2.imencode('.jpg', output_image[:, :, ::-1])
        if not ok:
            raise RuntimeError("Failed to encode the visualization as JPEG.")
        img_base64 = base64.b64encode(img_encoded).decode('utf-8')

        response_data = {
            "visualization": img_base64,
            "predictions": [
                {
                    "class_name": class_name,
                    "class_id": class_id,
                    "score": float(score),
                    "bbox": box.tolist(),
                }
                for class_name, class_id, score, box
                in zip(class_names, pred_classes, scores, boxes)
            ],
        }
        return JSONResponse(content=response_data)
    except Exception as e:
        # Unexpected failure: log the traceback server-side, keep the
        # existing {"error": ...} response contract for clients.
        logger.exception("Prediction failed")
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
if __name__ == "__main__":
    # uvicorn is imported only when the file is executed as a script, so
    # importing this module elsewhere does not require it.
    from uvicorn import run as run_server

    # Bind on all interfaces at port 7860.
    run_server(app, host="0.0.0.0", port=7860)