Subh775 committed on
Commit
db12d23
·
verified ·
1 Parent(s): b3c8c6c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import tempfile
5
+ import threading
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import numpy as np
8
+ from flask import Flask, request, jsonify, send_from_directory
9
+ import requests
10
+
11
+ # Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices
12
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
13
+
14
+ # --- model import (ensure rfdetr package is available in requirements) ---
15
+ try:
16
+ from rfdetr import RFDETRSegPreview
17
+ except Exception as e:
18
+ raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e
19
+
20
+ app = Flask(__name__, static_folder="static", static_url_path="/")
21
+
22
+ # HF checkpoint raw resolve URL (use the 'resolve/main' raw link)
23
+ CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-2/resolve/main/checkpoint_best_total.pth"
24
+ CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
25
+
26
+ MODEL_LOCK = threading.Lock()
27
+ MODEL = None
28
+
29
+ def download_file(url: str, dst: str):
30
+ if os.path.exists(dst):
31
+ return dst
32
+ print(f"[INFO] Downloading weights from {url} ...")
33
+ r = requests.get(url, stream=True, timeout=60)
34
+ r.raise_for_status()
35
+ with open(dst, "wb") as fh:
36
+ for chunk in r.iter_content(chunk_size=8192):
37
+ if chunk:
38
+ fh.write(chunk)
39
+ print("[INFO] Download complete.")
40
+ return dst
41
+
42
+ def init_model():
43
+ global MODEL
44
+ with MODEL_LOCK:
45
+ if MODEL is None:
46
+ # Ensure model checkpoint
47
+ try:
48
+ download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
49
+ except Exception as e:
50
+ print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.")
51
+ # continue; model may fallback to default weights
52
+ print("[INFO] Loading RF-DETR model (CPU mode)...")
53
+ MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None)
54
+ try:
55
+ MODEL.optimize_for_inference()
56
+ except Exception:
57
+ # optimization may fail on CPU or if not implemented; ignore
58
+ pass
59
+ print("[INFO] Model ready.")
60
+ return MODEL
61
+
62
@app.route("/")
def index():
    """Serve the single-page UI shipped in the static directory."""
    page = "index.html"
    return send_from_directory("static", page)
65
+
66
+ def decode_data_url(data_url: str) -> Image.Image:
67
+ if data_url.startswith("data:"):
68
+ header, b64 = data_url.split(",", 1)
69
+ data = base64.b64decode(b64)
70
+ return Image.open(io.BytesIO(data)).convert("RGB")
71
+ else:
72
+ # assume plain base64 or path
73
+ data = base64.b64decode(data_url)
74
+ return Image.open(io.BytesIO(data)).convert("RGB")
75
+
76
+ def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"):
77
+ buf = io.BytesIO()
78
+ pil_img.save(buf, format=fmt)
79
+ b = base64.b64encode(buf.getvalue()).decode("ascii")
80
+ return f"data:image/{fmt.lower()};base64,{b}"
81
+
82
+ def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.25, mask_color=(255,77,166), alpha=0.45):
83
+ """
84
+ masks: either list of HxW bool arrays or numpy array (N,H,W)
85
+ confidences: list of floats
86
+ Returns annotated PIL image and list of kept confidences and count.
87
+ """
88
+ base = pil_img.convert("RGBA")
89
+ W, H = base.size
90
+
91
+ # Normalize masks to N,H,W
92
+ if masks is None:
93
+ return base, []
94
+
95
+ if isinstance(masks, list):
96
+ masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
97
+ else:
98
+ masks_arr = np.asarray(masks)
99
+ # masks might be (H,W,N) -> transpose
100
+ if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
101
+ masks_arr = masks_arr.transpose(2, 0, 1)
102
+
103
+ # create overlay
104
+ overlay = Image.new("RGBA", (W, H), (0,0,0,0))
105
+ draw = ImageDraw.Draw(overlay)
106
+
107
+ kept_confidences = []
108
+ for i in range(masks_arr.shape[0]):
109
+ conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0
110
+ if conf < threshold:
111
+ continue
112
+ mask = masks_arr[i].astype(np.uint8) * 255
113
+ mask_img = Image.fromarray(mask).convert("L").resize((W, H), resample=Image.NEAREST)
114
+ # create colored mask image
115
+ color_layer = Image.new("RGBA", (W,H), mask_color + (0,))
116
+ # put alpha using mask
117
+ color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
118
+ overlay = Image.alpha_composite(overlay, color_layer)
119
+ kept_confidences.append(conf)
120
+
121
+ # composite
122
+ annotated = Image.alpha_composite(base, overlay)
123
+
124
+ # add confidence text (show highest kept confidence)
125
+ if len(kept_confidences) > 0:
126
+ best = max(kept_confidences)
127
+ draw = ImageDraw.Draw(annotated)
128
+ try:
129
+ # Try to use a builtin font
130
+ font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, W//30))
131
+ except Exception:
132
+ font = ImageFont.load_default()
133
+ text = f"Confidence: {best:.2f}"
134
+ # draw background box for text
135
+ tw, th = draw.textsize(text, font=font)
136
+ pad = 8
137
+ draw.rectangle([6,6, 6+tw+pad, 6+th+pad], fill=(0,0,0,180))
138
+ draw.text((6+4,6+2), text, font=font, fill=(255,255,255,255))
139
+ return annotated.convert("RGB"), kept_confidences
140
+
141
+ @app.route("/predict", methods=["POST"])
142
+ def predict():
143
+ payload = request.get_json(force=True)
144
+ if not payload or "image" not in payload:
145
+ return jsonify({"error": "Missing image"}), 400
146
+ conf = float(payload.get("conf", 0.25))
147
+
148
+ # ensure model ready
149
+ model = init_model()
150
+
151
+ # decode image
152
+ try:
153
+ pil = decode_data_url(payload["image"])
154
+ except Exception as e:
155
+ return jsonify({"error": f"Invalid image: {e}"}), 400
156
+
157
+ # perform prediction (model.predict expects PIL image)
158
+ try:
159
+ detections = model.predict(pil, threshold=0.0) # we filter using conf manually
160
+ except Exception as e:
161
+ return jsonify({"error": f"Inference failure: {e}"}), 500
162
+
163
+ # extract masks and confidences
164
+ masks = getattr(detections, "masks", None)
165
+ confidences = []
166
+ # attempt to read per-instance confidence
167
+ try:
168
+ confidences = [float(x) for x in getattr(detections, "confidence", [])]
169
+ except Exception:
170
+ # fallback: attempt attribute 'scores' or 'scores_' or generate ones
171
+ confidences = []
172
+ try:
173
+ confidences = [float(x) for x in getattr(detections, "scores", [])]
174
+ except Exception:
175
+ confidences = [1.0] * (masks.shape[0] if masks is not None and hasattr(masks, "shape") and masks.shape[0] else 0)
176
+
177
+ # overlay mask with pink-red color
178
+ mask_color = (255, 77, 166) # pinkish
179
+ annotated_pil, kept_conf = overlay_mask_on_image(pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45)
180
+
181
+ data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
182
+ return jsonify({
183
+ "annotated": data_url,
184
+ "confidences": kept_conf,
185
+ "count": len(kept_conf)
186
+ })
187
+
188
+ if __name__ == "__main__":
189
+ # warm up model on startup (non-blocking)
190
+ try:
191
+ init_model()
192
+ except Exception as e:
193
+ print("Model init warning:", e)
194
+ app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)