Spaces:
Sleeping
Sleeping
import base64
import threading
import traceback

import cv2
import fastapi
import numpy as np
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
# ASGI application object; the __main__ guard at the bottom serves it with uvicorn.
app = FastAPI()

# Model state shared by the handlers. load_model() rebinds both names once
# loading succeeds; until then they remain None.
predictor = metadata = None
@app.on_event("startup")
async def load_model(
    config_path: str = "mask_rcnn_config.yaml",
    model_path: str = "model_final.pth",
    device: str = "cpu",
    class_names=("lesion", "light", "mucus"),
):
    """Load the Detectron2 predictor and class metadata into module globals.

    Registered as a FastAPI startup handler, so it runs once (with the default
    arguments) before the server starts accepting requests.

    On success it sets the module-level ``predictor`` and ``metadata``. On
    failure it prints the error plus a traceback and leaves both globals
    untouched (initially None). The server still starts in that case, but no
    model is available.

    Args:
        config_path: Detectron2 YAML config merged into the default config.
        model_path: Path to the trained weights file.
        device: Device string assigned to ``cfg.MODEL.DEVICE``.
        class_names: Display names indexed by predicted class id. The order
            must match the class indices the model was trained with.
    """
    global predictor, metadata
    try:
        cfg = get_cfg()
        cfg.merge_from_file(config_path)
        cfg.MODEL.WEIGHTS = model_path
        cfg.MODEL.DEVICE = device

        # Predictions are mapped to names by indexing thing_classes with the
        # predicted class id, so this order must match the training labels.
        meta = MetadataCatalog.get("medical_train")
        meta.thing_classes = list(class_names)

        model = DefaultPredictor(cfg)
    except Exception as e:
        # Best-effort: keep the server up, but make the cause diagnosable.
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return

    # Publish the globals only after every step above has succeeded.
    predictor, metadata = model, meta
    print("Model loaded successfully.")
# Serializes model calls and rendering. The original code ran inference
# directly on the event loop, which implicitly allowed only one prediction at
# a time. Now that the work runs in a worker thread, this lock keeps that
# one-at-a-time behaviour instead of letting many requests run the model
# (and the plotting code behind Visualizer) concurrently.
_inference_lock = threading.Lock()


def _run_inference(model, meta, image):
    """Run the model on a decoded image and build the JSON-ready response dict.

    This is synchronous, CPU-bound work. Async callers should run it in a
    worker thread rather than on the event loop.

    Args:
        model: The loaded ``DefaultPredictor``.
        meta: Metadata whose ``thing_classes`` maps class ids to names.
        image: Array returned by ``cv2.imdecode(..., cv2.IMREAD_COLOR)``
            (OpenCV BGR channel order).

    Returns:
        dict with a base64 JPEG ``"visualization"`` and a ``"predictions"``
        list of ``class_name`` / ``class_id`` / ``score`` / ``bbox`` entries.

    Raises:
        RuntimeError: if the visualization cannot be JPEG-encoded.
    """
    with _inference_lock:
        outputs = model(image)
        instances = outputs["instances"].to("cpu")

        pred_classes = instances.pred_classes.tolist()
        scores = instances.scores.tolist()
        boxes = instances.pred_boxes.tensor.numpy()
        class_names = [meta.thing_classes[idx] for idx in pred_classes]

        # OpenCV decoded the image as BGR. Flip to RGB for Visualizer, then
        # flip its RGB output back to BGR for cv2.imencode. The rendering is
        # scaled by 1.2, but the bbox values come straight from the model
        # output, not from the rendered image.
        visualizer = Visualizer(image[:, :, ::-1], meta, scale=1.2)
        output_image = visualizer.draw_instance_predictions(instances).get_image()
        ok, img_encoded = cv2.imencode(".jpg", output_image[:, :, ::-1])
        if not ok:
            raise RuntimeError("Failed to encode visualization image")

    return {
        "visualization": base64.b64encode(img_encoded).decode("utf-8"),
        "predictions": [
            {
                "class_name": class_name,
                "class_id": class_id,
                "score": float(score),
                "bbox": box.tolist(),
            }
            for class_name, class_id, score, box in zip(
                class_names, pred_classes, scores, boxes
            )
        ],
    }


@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Run instance segmentation on an uploaded image.

    On success, returns JSON with a base64 JPEG ``visualization`` and a
    ``predictions`` list (class name, class id, score and bbox per instance).

    Errors are returned as ``{"error": message}`` with one of these statuses:
      * 503 if the model has not been loaded;
      * 400 if the upload is empty or cannot be decoded as an image;
      * 500 for any other failure during inference or encoding.
    """
    if predictor is None or metadata is None:
        return JSONResponse(
            status_code=503,
            content={"error": "Model is not loaded"},
        )
    try:
        img_bytes = await file.read()
        if not img_bytes:
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded file is empty"},
            )

        npimg = np.frombuffer(img_bytes, np.uint8)
        image = cv2.imdecode(npimg, cv2.IMREAD_COLOR)
        # imdecode signals unreadable/corrupt data by returning None.
        if image is None:
            return JSONResponse(
                status_code=400,
                content={"error": "Could not decode uploaded file as an image"},
            )

        # Offload the CPU-heavy work so the event loop keeps serving requests.
        response_data = await run_in_threadpool(
            _run_inference, predictor, metadata, image
        )
        return JSONResponse(content=response_data)
    except Exception as e:
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces at port 7860
    # (presumably the port the hosting Space expects; confirm before changing).
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")