Subh775 commited on
Commit
f8a9f51
·
verified ·
1 Parent(s): 9994b0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +271 -171
app.py CHANGED
@@ -1,194 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import io
3
- import base64
4
- import tempfile
5
- import threading
6
- from PIL import Image, ImageDraw, ImageFont
7
  import numpy as np
8
- from flask import Flask, request, jsonify, send_from_directory
9
  import requests
 
 
 
10
 
11
- # Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices
12
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
13
-
14
- # --- model import (ensure rfdetr package is available in requirements) ---
15
- try:
16
- from rfdetr import RFDETRSegPreview
17
- except Exception as e:
18
- raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e
19
-
20
- app = Flask(__name__, static_folder="static", static_url_path="/")
21
 
22
- # HF checkpoint raw resolve URL (use the 'resolve/main' raw link)
23
- CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
24
- CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
25
-
26
- MODEL_LOCK = threading.Lock()
27
- MODEL = None
28
 
 
29
  def download_file(url: str, dst: str):
 
30
  if os.path.exists(dst):
 
31
  return dst
32
  print(f"[INFO] Downloading weights from {url} ...")
33
- r = requests.get(url, stream=True, timeout=60)
34
  r.raise_for_status()
35
- with open(dst, "wb") as fh:
36
  for chunk in r.iter_content(chunk_size=8192):
37
- if chunk:
38
- fh.write(chunk)
39
  print("[INFO] Download complete.")
40
  return dst
41
 
42
- def init_model():
43
- global MODEL
44
- with MODEL_LOCK:
45
- if MODEL is None:
46
- # Ensure model checkpoint
47
- try:
48
- download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
49
- except Exception as e:
50
- print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.")
51
- # continue; model may fallback to default weights
52
- print("[INFO] Loading RF-DETR model (CPU mode)...")
53
- MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None)
54
- try:
55
- MODEL.optimize_for_inference()
56
- except Exception:
57
- # optimization may fail on CPU or if not implemented; ignore
58
- pass
59
- print("[INFO] Model ready.")
60
- return MODEL
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  @app.route("/")
63
- def index():
64
- return send_from_directory("static", "index.html")
65
-
66
- def decode_data_url(data_url: str) -> Image.Image:
67
- if data_url.startswith("data:"):
68
- header, b64 = data_url.split(",", 1)
69
- data = base64.b64decode(b64)
70
- return Image.open(io.BytesIO(data)).convert("RGB")
71
- else:
72
- # assume plain base64 or path
73
- data = base64.b64decode(data_url)
74
- return Image.open(io.BytesIO(data)).convert("RGB")
75
-
76
- def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"):
77
- buf = io.BytesIO()
78
- pil_img.save(buf, format=fmt)
79
- b = base64.b64encode(buf.getvalue()).decode("ascii")
80
- return f"data:image/{fmt.lower()};base64,{b}"
81
-
82
- def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.01, mask_color=(255,77,166), alpha=0.45):
83
- """
84
- masks: either list of HxW bool arrays or numpy array (N,H,W)
85
- confidences: list of floats
86
- Returns annotated PIL image and list of kept confidences and count.
87
- """
88
- base = pil_img.convert("RGBA")
89
- W, H = base.size
90
-
91
- # Normalize masks to N,H,W
92
- if masks is None:
93
- return base, []
94
-
95
- if isinstance(masks, list):
96
- masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
97
- else:
98
- masks_arr = np.asarray(masks)
99
- # masks might be (H,W,N) -> transpose
100
- if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
101
- masks_arr = masks_arr.transpose(2, 0, 1)
102
-
103
- # create overlay
104
- overlay = Image.new("RGBA", (W, H), (0,0,0,0))
105
- draw = ImageDraw.Draw(overlay)
106
-
107
- kept_confidences = []
108
- for i in range(masks_arr.shape[0]):
109
- conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0
110
- if conf < threshold:
111
- continue
112
- mask = masks_arr[i].astype(np.uint8) * 255
113
- mask_img = Image.fromarray(mask).convert("L").resize((W, H), resample=Image.NEAREST)
114
- # create colored mask image
115
- color_layer = Image.new("RGBA", (W,H), mask_color + (0,))
116
- # put alpha using mask
117
- color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
118
- overlay = Image.alpha_composite(overlay, color_layer)
119
- kept_confidences.append(conf)
120
-
121
- # composite
122
- annotated = Image.alpha_composite(base, overlay)
123
-
124
- # add confidence text (show highest kept confidence)
125
- if len(kept_confidences) > 0:
126
- best = max(kept_confidences)
127
- draw = ImageDraw.Draw(annotated)
128
- try:
129
- # Try to use a builtin font
130
- font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, W//30))
131
- except Exception:
132
- font = ImageFont.load_default()
133
- text = f"Confidence: {best:.2f}"
134
- # draw background box for text
135
- tw, th = draw.textsize(text, font=font)
136
- pad = 8
137
- draw.rectangle([6,6, 6+tw+pad, 6+th+pad], fill=(0,0,0,180))
138
- draw.text((6+4,6+2), text, font=font, fill=(255,255,255,255))
139
- return annotated.convert("RGB"), kept_confidences
140
 
141
  @app.route("/predict", methods=["POST"])
142
  def predict():
143
- payload = request.get_json(force=True)
144
- if not payload or "image" not in payload:
145
- return jsonify({"error": "Missing image"}), 400
146
- conf = float(payload.get("conf", 0.25))
147
-
148
- # ensure model ready
149
- model = init_model()
150
-
151
- # decode image
152
- try:
153
- pil = decode_data_url(payload["image"])
154
- except Exception as e:
155
- return jsonify({"error": f"Invalid image: {e}"}), 400
156
-
157
- # perform prediction (model.predict expects PIL image)
158
- try:
159
- detections = model.predict(pil, threshold=0.0) # we filter using conf manually
160
- except Exception as e:
161
- return jsonify({"error": f"Inference failure: {e}"}), 500
162
-
163
- # extract masks and confidences
164
- masks = getattr(detections, "masks", None)
165
- confidences = []
166
- # attempt to read per-instance confidence
167
- try:
168
- confidences = [float(x) for x in getattr(detections, "confidence", [])]
169
- except Exception:
170
- # fallback: attempt attribute 'scores' or 'scores_' or generate ones
171
- confidences = []
172
- try:
173
- confidences = [float(x) for x in getattr(detections, "scores", [])]
174
- except Exception:
175
- confidences = [1.0] * (masks.shape[0] if masks is not None and hasattr(masks, "shape") and masks.shape[0] else 0)
176
-
177
- # overlay mask with pink-red color
178
- mask_color = (255, 77, 166) # pinkish
179
- annotated_pil, kept_conf = overlay_mask_on_image(pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45)
180
-
181
- data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
182
- return jsonify({
183
- "annotated": data_url,
184
- "confidences": kept_conf,
185
- "count": len(kept_conf)
186
- })
187
 
188
  if __name__ == "__main__":
189
- # warm up model on startup (non-blocking)
190
- try:
191
- init_model()
192
- except Exception as e:
193
- print("Model init warning:", e)
194
- app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
 
1
+ # import os
2
+ # import io
3
+ # import base64
4
+ # import tempfile
5
+ # import threading
6
+ # from PIL import Image, ImageDraw, ImageFont
7
+ # import numpy as np
8
+ # from flask import Flask, request, jsonify, send_from_directory
9
+ # import requests
10
+
11
+ # # Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices
12
+ # os.environ["CUDA_VISIBLE_DEVICES"] = ""
13
+
14
+ # # --- model import (ensure rfdetr package is available in requirements) ---
15
+ # try:
16
+ # from rfdetr import RFDETRSegPreview
17
+ # except Exception as e:
18
+ # raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e
19
+
20
+ # app = Flask(__name__, static_folder="static", static_url_path="/")
21
+
22
+ # # HF checkpoint raw resolve URL (use the 'resolve/main' raw link)
23
+ # CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
24
+ # CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
25
+
26
+ # MODEL_LOCK = threading.Lock()
27
+ # MODEL = None
28
+
29
+ # def download_file(url: str, dst: str):
30
+ # if os.path.exists(dst):
31
+ # return dst
32
+ # print(f"[INFO] Downloading weights from {url} ...")
33
+ # r = requests.get(url, stream=True, timeout=60)
34
+ # r.raise_for_status()
35
+ # with open(dst, "wb") as fh:
36
+ # for chunk in r.iter_content(chunk_size=8192):
37
+ # if chunk:
38
+ # fh.write(chunk)
39
+ # print("[INFO] Download complete.")
40
+ # return dst
41
+
42
+ # def init_model():
43
+ # global MODEL
44
+ # with MODEL_LOCK:
45
+ # if MODEL is None:
46
+ # # Ensure model checkpoint
47
+ # try:
48
+ # download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
49
+ # except Exception as e:
50
+ # print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.")
51
+ # # continue; model may fallback to default weights
52
+ # print("[INFO] Loading RF-DETR model (CPU mode)...")
53
+ # MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None)
54
+ # try:
55
+ # MODEL.optimize_for_inference()
56
+ # except Exception:
57
+ # # optimization may fail on CPU or if not implemented; ignore
58
+ # pass
59
+ # print("[INFO] Model ready.")
60
+ # return MODEL
61
+
62
+ # @app.route("/")
63
+ # def index():
64
+ # return send_from_directory("static", "index.html")
65
+
66
+ # def decode_data_url(data_url: str) -> Image.Image:
67
+ # if data_url.startswith("data:"):
68
+ # header, b64 = data_url.split(",", 1)
69
+ # data = base64.b64decode(b64)
70
+ # return Image.open(io.BytesIO(data)).convert("RGB")
71
+ # else:
72
+ # # assume plain base64 or path
73
+ # data = base64.b64decode(data_url)
74
+ # return Image.open(io.BytesIO(data)).convert("RGB")
75
+
76
+ # def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"):
77
+ # buf = io.BytesIO()
78
+ # pil_img.save(buf, format=fmt)
79
+ # b = base64.b64encode(buf.getvalue()).decode("ascii")
80
+ # return f"data:image/{fmt.lower()};base64,{b}"
81
+
82
+ # def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.01, mask_color=(255,77,166), alpha=0.45):
83
+ # """
84
+ # masks: either list of HxW bool arrays or numpy array (N,H,W)
85
+ # confidences: list of floats
86
+ # Returns annotated PIL image and list of kept confidences and count.
87
+ # """
88
+ # base = pil_img.convert("RGBA")
89
+ # W, H = base.size
90
+
91
+ # # Normalize masks to N,H,W
92
+ # if masks is None:
93
+ # return base, []
94
+
95
+ # if isinstance(masks, list):
96
+ # masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
97
+ # else:
98
+ # masks_arr = np.asarray(masks)
99
+ # # masks might be (H,W,N) -> transpose
100
+ # if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
101
+ # masks_arr = masks_arr.transpose(2, 0, 1)
102
+
103
+ # # create overlay
104
+ # overlay = Image.new("RGBA", (W, H), (0,0,0,0))
105
+ # draw = ImageDraw.Draw(overlay)
106
+
107
+ # kept_confidences = []
108
+ # for i in range(masks_arr.shape[0]):
109
+ # conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0
110
+ # if conf < threshold:
111
+ # continue
112
+ # mask = masks_arr[i].astype(np.uint8) * 255
113
+ # mask_img = Image.fromarray(mask).convert("L").resize((W, H), resample=Image.NEAREST)
114
+ # # create colored mask image
115
+ # color_layer = Image.new("RGBA", (W,H), mask_color + (0,))
116
+ # # put alpha using mask
117
+ # color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
118
+ # overlay = Image.alpha_composite(overlay, color_layer)
119
+ # kept_confidences.append(conf)
120
+
121
+ # # composite
122
+ # annotated = Image.alpha_composite(base, overlay)
123
+
124
+ # # add confidence text (show highest kept confidence)
125
+ # if len(kept_confidences) > 0:
126
+ # best = max(kept_confidences)
127
+ # draw = ImageDraw.Draw(annotated)
128
+ # try:
129
+ # # Try to use a builtin font
130
+ # font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, W//30))
131
+ # except Exception:
132
+ # font = ImageFont.load_default()
133
+ # text = f"Confidence: {best:.2f}"
134
+ # # draw background box for text
135
+ # tw, th = draw.textsize(text, font=font)
136
+ # pad = 8
137
+ # draw.rectangle([6,6, 6+tw+pad, 6+th+pad], fill=(0,0,0,180))
138
+ # draw.text((6+4,6+2), text, font=font, fill=(255,255,255,255))
139
+ # return annotated.convert("RGB"), kept_confidences
140
+
141
+ # @app.route("/predict", methods=["POST"])
142
+ # def predict():
143
+ # payload = request.get_json(force=True)
144
+ # if not payload or "image" not in payload:
145
+ # return jsonify({"error": "Missing image"}), 400
146
+ # conf = float(payload.get("conf", 0.25))
147
+
148
+ # # ensure model ready
149
+ # model = init_model()
150
+
151
+ # # decode image
152
+ # try:
153
+ # pil = decode_data_url(payload["image"])
154
+ # except Exception as e:
155
+ # return jsonify({"error": f"Invalid image: {e}"}), 400
156
+
157
+ # # perform prediction (model.predict expects PIL image)
158
+ # try:
159
+ # detections = model.predict(pil, threshold=0.0) # we filter using conf manually
160
+ # except Exception as e:
161
+ # return jsonify({"error": f"Inference failure: {e}"}), 500
162
+
163
+ # # extract masks and confidences
164
+ # masks = getattr(detections, "masks", None)
165
+ # confidences = []
166
+ # # attempt to read per-instance confidence
167
+ # try:
168
+ # confidences = [float(x) for x in getattr(detections, "confidence", [])]
169
+ # except Exception:
170
+ # # fallback: attempt attribute 'scores' or 'scores_' or generate ones
171
+ # confidences = []
172
+ # try:
173
+ # confidences = [float(x) for x in getattr(detections, "scores", [])]
174
+ # except Exception:
175
+ # confidences = [1.0] * (masks.shape[0] if masks is not None and hasattr(masks, "shape") and masks.shape[0] else 0)
176
+
177
+ # # overlay mask with pink-red color
178
+ # mask_color = (255, 77, 166) # pinkish
179
+ # annotated_pil, kept_conf = overlay_mask_on_image(pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45)
180
+
181
+ # data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
182
+ # return jsonify({
183
+ # "annotated": data_url,
184
+ # "confidences": kept_conf,
185
+ # "count": len(kept_conf)
186
+ # })
187
+
188
+ # if __name__ == "__main__":
189
+ # # warm up model on startup (non-blocking)
190
+ # try:
191
+ # init_model()
192
+ # except Exception as e:
193
+ # print("Model init warning:", e)
194
+ # app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
195
+
196
+
197
  import os
198
  import io
 
 
 
 
199
  import numpy as np
200
+ from PIL import Image
201
  import requests
202
+ import supervision as sv
203
+ from flask import Flask, request, jsonify, send_file
204
+ from rfdetr import RFDETRSegPreview
205
 
206
+ app = Flask(__name__)
 
 
 
 
 
 
 
 
 
207
 
208
# ---- CONFIG ----
# Use the "resolve" URL so requests receives the raw checkpoint bytes,
# not the Hugging Face HTML viewer page.
WEIGHTS_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
# /tmp is writable on Hugging Face Spaces; the download is cached there
# for the lifetime of the process.
WEIGHTS_PATH = "/tmp/checkpoint_best_total.pth"
 
 
 
211
 
212
+ # ---- HELPERS ----
213
def download_file(url: str, dst: str):
    """Download model weights to *dst*, skipping the download if cached.

    Args:
        url: Direct (resolve) URL to the checkpoint file.
        dst: Local filesystem path where the weights are stored.

    Returns:
        str: The destination path, whether freshly downloaded or already cached.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the server stalls past the timeout.
    """
    if os.path.exists(dst):
        print(f"[INFO] Weights already exist at {dst}")
        return dst
    print(f"[INFO] Downloading weights from {url} ...")
    # A timeout prevents the app from hanging forever at startup if the
    # weights host is unreachable (the previous revision also used 60s).
    r = requests.get(url, stream=True, timeout=60)
    r.raise_for_status()
    with open(dst, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)
    print("[INFO] Download complete.")
    return dst
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
def annotate_segmentation(image: Image.Image, detections: sv.Detections):
    """Render mask fills, white polygon outlines, and confidence labels
    onto a copy of *image*, leaving the original untouched."""
    colors = sv.ColorPalette.from_hex([
        "#ff9b00", "#ff8080", "#ff66b2", "#b266ff",
        "#9999ff", "#3399ff", "#33ff99", "#99ff00"
    ])
    scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)

    mask_layer = sv.MaskAnnotator(color=colors)
    outline_layer = sv.PolygonAnnotator(color=sv.Color.WHITE)
    label_layer = sv.LabelAnnotator(
        color=colors,
        text_color=sv.Color.BLACK,
        text_scale=scale,
        text_position=sv.Position.CENTER_OF_MASS
    )

    # Labels carry only the confidence score; class ids are deliberately omitted.
    captions = ["%.2f" % score for score in detections.confidence]

    result = image.copy()
    for draw in (mask_layer.annotate, outline_layer.annotate):
        result = draw(result, detections)
    return label_layer.annotate(result, detections, captions)
253
+
254
+
255
# ---- MODEL INITIALIZATION ----
# Runs at import time so the first request does not pay the download/load cost.
print("[INFO] Loading RF-DETR model (CPU mode)...")
download_file(WEIGHTS_URL, WEIGHTS_PATH)
# NOTE(review): unlike the previous revision, nothing here hides CUDA devices;
# "CPU mode" holds only if the host has no GPU — confirm on target hardware.
model = RFDETRSegPreview(pretrain_weights=WEIGHTS_PATH)
try:
    model.optimize_for_inference()
except Exception as e:
    # Best-effort optimization; it may be unsupported on CPU-only setups,
    # and the model works without it.
    print(f"[WARN] optimize_for_inference() skipped: {e}")
print("[INFO] Model ready.")
264
+
265
+
266
+ # ---- ROUTES ----
267
@app.route("/")
def home():
    """Health-check endpoint confirming the API process is up."""
    status = {"message": "RF-DETR Segmentation API is running."}
    return jsonify(status)
270
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
@app.route("/predict", methods=["POST"])
def predict():
    """Run segmentation on an uploaded image and return the annotated overlay.

    Expects a multipart/form-data POST with the image under the ``file`` key.
    An optional ``threshold`` form/query field (default 0.3, backward
    compatible) sets the confidence cutoff. Returns a PNG on success, or a
    JSON error with status 400 on bad input.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    upload = request.files["file"]  # renamed: `file` shadows the builtin
    try:
        image = Image.open(upload.stream).convert("RGB")
    except Exception as e:
        # Pillow raises UnidentifiedImageError (and friends) on corrupt
        # uploads; report a client error instead of a 500.
        return jsonify({"error": f"Invalid image: {e}"}), 400
    print(f"[INFO] Image received for inference: {upload.filename}")

    try:
        threshold = float(request.values.get("threshold", 0.3))
    except ValueError:
        return jsonify({"error": "threshold must be a number"}), 400

    detections = model.predict(image, threshold=threshold)
    # len(detections) is the instance count; the former
    # getattr(detections, 'boxes', []) matched no sv.Detections attribute
    # and always logged 0.
    print(f"[INFO] Detections found: {len(detections)}")

    annotated = annotate_segmentation(image, detections)

    buf = io.BytesIO()
    annotated.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
291
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
if __name__ == "__main__":
    # Honor the platform-provided PORT (e.g. Hugging Face Spaces) as the
    # previous revision did, falling back to 7860 locally.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))