Noursine committed on
Commit
92764bf
·
verified ·
1 Parent(s): 93bf16f

Create meaw.py

Browse files
Files changed (1) hide show
  1. meaw.py +116 -0
meaw.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import base64
import io
import os

import cv2
import gdown
import numpy as np
import torch
from PIL import Image

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer, ColorMode
# === Model Setup ===
MODEL_PATH = "/tmp/model_final.pth"
DRIVE_FILE_ID = "1bazIVYG0CYMubDLoMu5pgH6ArC1sayzg"


def download_model():
    """Fetch the Detectron2 weights from Google Drive into MODEL_PATH.

    No-op when the weight file is already present, so repeated
    startups skip the download.
    """
    if os.path.exists(MODEL_PATH):
        return

    # Give gdown a cache directory we know is writable.
    cache_dir = "/tmp/gdown"
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["GDOWN_CACHE_DIR"] = cache_dir

    print("Downloading Detectron2 model...")
    gdown.download(
        f"https://drive.google.com/uc?id={DRIVE_FILE_ID}",
        MODEL_PATH,
        quiet=False,
        fuzzy=True,
        use_cookies=False,  # Important for Hugging Face
    )
    print("Download complete.")
download_model()


# Register dummy metadata (for visualization only).
# BUG FIX: the original referenced `metadata_name` and `class_names`
# without ever defining them (NameError at import time), and `metadata`
# — used by predict() below — was never defined either. Concrete
# placeholders are supplied here.
metadata_name = "meaw_dataset"
class_names = ["object"]  # TODO: replace with the model's real class names
MetadataCatalog.get(metadata_name).set(thing_classes=class_names)
metadata = MetadataCatalog.get(metadata_name)

# === Config Setup ===
cfg = get_cfg()
cfg.merge_from_file(
    model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
)
cfg.MODEL.WEIGHTS = MODEL_PATH
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # drop detections below 50% confidence
cfg.MODEL.DEVICE = "cpu"  # or "cuda" if GPU is enabled

# === Predictor ===
predictor = DefaultPredictor(cfg)
# -----------------------------
# 3. Helper: Encode mask to Base64
# -----------------------------
def encode_mask(mask: np.ndarray) -> str:
    """Convert mask numpy array to base64 PNG string.

    The mask is cast to uint8 before encoding, so a boolean mask
    becomes 0/1 pixel values — scale to 0/255 beforehand (as the
    endpoint below does) if a visually inspectable PNG is needed.
    """
    mask_img = Image.fromarray(mask.astype(np.uint8))
    # Encode in memory; no temp file is written.
    buf = io.BytesIO()
    mask_img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
# === Predict Function ===
def predict(image_path: str):
    """Run instance segmentation on an image file and return the rendering.

    Args:
        image_path: Path to an image readable by OpenCV.

    Returns:
        BGR numpy array with the predicted instances drawn on top.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    im = cv2.imread(image_path)
    if im is None:
        # cv2.imread returns None instead of raising on a bad path;
        # fail loudly rather than crash inside the predictor.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    outputs = predictor(im)
    instances = outputs["instances"].to("cpu")

    # Visualization — Visualizer expects RGB, hence the channel flip.
    v = Visualizer(
        im[:, :, ::-1],
        metadata=metadata,
        scale=1.2,
        instance_mode=ColorMode.IMAGE
    )
    out = v.draw_instance_predictions(instances)
    result_img = out.get_image()[:, :, ::-1]

    # BUG FIX: the original computed result_img and silently discarded
    # it — the function returned None.
    return result_img
# -----------------------------
# 4. API Endpoint
# -----------------------------
# BUG FIX: the FastAPI app instance was never created, so the
# @app.post decorator below raised NameError at import time.
app = FastAPI()


@app.post("/predict")
async def predict_endpoint(file: UploadFile = File(...)):
    """Segment an uploaded image.

    Returns JSON with per-instance bounding boxes and scores, plus one
    combined binary mask encoded as a base64 PNG (``null`` when nothing
    was detected).

    NOTE: renamed from ``predict`` — the original name clobbered the
    module-level predict() helper; the HTTP route path is unchanged.
    """
    # Read image
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    image = np.array(image)[:, :, ::-1]  # to BGR for OpenCV/Detectron2

    # Run inference
    outputs = predictor(image)
    instances = outputs["instances"].to("cpu")

    results = []
    mask_b64 = None

    if instances.has("pred_masks"):
        masks = instances.pred_masks.numpy()
        boxes = instances.pred_boxes.tensor.numpy()
        scores = instances.scores.numpy()

        # Combine all instance masks into one 0/255 uint8 image.
        combined_mask = np.any(masks, axis=0).astype(np.uint8) * 255
        mask_b64 = encode_mask(combined_mask)

        for i in range(len(masks)):
            results.append({
                "box": boxes[i].tolist(),
                "score": float(scores[i])
            })

    return JSONResponse({
        "predictions": results,
        "mask": mask_b64  # base64 string (PNG)
    })