Subh775 committed on
Commit
84395d7
·
verified ·
1 Parent(s): c8d1052

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -139
app.py CHANGED
@@ -193,32 +193,34 @@
193
  # print("Model init warning:", e)
194
  # app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
195
 
 
 
196
  import os
197
  import io
198
  import base64
199
  import threading
200
- import tempfile
201
  import traceback
202
  from typing import Optional
203
 
204
- from flask import Flask, request, jsonify, send_from_directory, send_file
205
- from PIL import Image, ImageDraw, ImageFont
206
  import numpy as np
207
  import requests
 
208
 
209
- # Set writable cache dirs to avoid matplotlib/fontconfig warnings in containers
210
  os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
211
  os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
212
- # Ensure CPU-only (do not accidentally use GPU)
213
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
 
 
 
 
 
 
214
 
215
- # --- Imports that may trigger the above warnings ---
216
- try:
217
- import supervision as sv
218
- from rfdetr import RFDETRSegPreview
219
- except Exception as e:
220
- # Provide a clearer error at startup if imports fail
221
- raise RuntimeError(f"Required library import failed: {e}")
222
 
223
  app = Flask(__name__, static_folder="static", static_url_path="/")
224
 
@@ -231,10 +233,12 @@ MODEL = None
231
 
232
 
233
  def download_file(url: str, dst: str, chunk_size: int = 8192):
 
234
  if os.path.exists(dst) and os.path.getsize(dst) > 0:
 
235
  return dst
236
  print(f"[INFO] Downloading weights from {url} -> {dst}")
237
- r = requests.get(url, stream=True, timeout=60)
238
  r.raise_for_status()
239
  with open(dst, "wb") as fh:
240
  for chunk in r.iter_content(chunk_size=chunk_size):
@@ -245,27 +249,29 @@ def download_file(url: str, dst: str, chunk_size: int = 8192):
245
 
246
 
247
  def init_model():
248
- """
249
- Lazily initialize the RF-DETR model and cache it in global MODEL.
250
- Thread-safe.
251
- """
252
  global MODEL
253
  with MODEL_LOCK:
254
  if MODEL is not None:
255
  return MODEL
256
  try:
257
- # ensure checkpoint present (best-effort)
258
  try:
259
  download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
260
  except Exception as e:
261
  print("[WARN] Failed to download checkpoint:", e)
 
 
262
 
263
  print("[INFO] Loading RF-DETR model (CPU mode)...")
264
- MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None)
 
 
265
  try:
266
  MODEL.optimize_for_inference()
267
  except Exception as e:
268
  print("[WARN] optimize_for_inference() skipped/failed:", e)
 
269
  print("[INFO] Model ready.")
270
  return MODEL
271
  except Exception:
@@ -274,14 +280,11 @@ def init_model():
274
 
275
 
276
  def decode_data_url(data_url: str) -> Image.Image:
277
- """
278
- Accepts a data URL (data:image/png;base64,...) or raw base64 and returns PIL.Image (RGB)
279
- """
280
  if data_url.startswith("data:"):
281
  _, b64 = data_url.split(",", 1)
282
  data = base64.b64decode(b64)
283
  else:
284
- # assume raw base64 or binary string
285
  try:
286
  data = base64.b64decode(data_url)
287
  except Exception:
@@ -290,117 +293,54 @@ def decode_data_url(data_url: str) -> Image.Image:
290
 
291
 
292
  def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
 
293
  buf = io.BytesIO()
294
  pil_img.save(buf, format=fmt)
295
  return "data:image/{};base64,".format(fmt.lower()) + base64.b64encode(buf.getvalue()).decode("ascii")
296
 
297
 
298
- def overlay_mask_on_image(pil_img: Image.Image, detections, threshold: float = 0.25,
299
- mask_color=(255, 77, 166), alpha=0.45):
300
  """
301
- Create annotated PIL image by overlaying per-instance masks (pink) and polygon borders,
302
- and add confidence text (best confidence) on the image.
303
- Uses supervision-like masks if available, otherwise attempts to use detections.masks.
304
- Returns (annotated_pil_rgb, kept_confidences_list)
305
  """
306
- base = pil_img.convert("RGBA")
307
- W, H = base.size
308
-
309
- masks = getattr(detections, "masks", None)
310
- confidences = []
311
- try:
312
- confidences = [float(x) for x in getattr(detections, "confidence", [])]
313
- except Exception:
314
- # fallback to 'scores' or empty
315
- try:
316
- confidences = [float(x) for x in getattr(detections, "scores", [])]
317
- except Exception:
318
- confidences = []
319
-
320
- if masks is None:
321
- # no masks -> return original image and empty list
322
- return pil_img.convert("RGB"), []
323
-
324
- # Normalize mask array to (N, H, W)
325
- if isinstance(masks, list):
326
- masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
327
- else:
328
- masks_arr = np.asarray(masks)
329
- # some outputs might be (H, W, N)
330
- if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
331
- masks_arr = masks_arr.transpose(2, 0, 1)
332
-
333
- # overlay image we will composite
334
- overlay = Image.new("RGBA", (W, H), (0, 0, 0, 0))
335
- kept_confidences = []
336
-
337
- for i in range(masks_arr.shape[0]):
338
- conf = confidences[i] if i < len(confidences) else 1.0
339
- if conf < threshold:
340
- continue
341
- mask = masks_arr[i].astype(np.uint8) * 255
342
- mask_img = Image.fromarray(mask).convert("L")
343
- # if mask size doesn't match, resize
344
- if mask_img.size != (W, H):
345
- mask_img = mask_img.resize((W, H), resample=Image.NEAREST)
346
-
347
- # color layer with alpha
348
- color_layer = Image.new("RGBA", (W, H), mask_color + (0,))
349
- # compute per-pixel alpha from mask (0..255) scaled by alpha
350
- alpha_mask = mask_img.point(lambda p: int(p * alpha))
351
- color_layer.putalpha(alpha_mask)
352
- overlay = Image.alpha_composite(overlay, color_layer)
353
- kept_confidences.append(float(conf))
354
-
355
- # draw polygon outlines for visual crispness using supervision polygonifier if available
356
- try:
357
- # try to use supervision polygonizer if detections contains polygons
358
- # fallback: create thin white outline by expanding mask boundaries
359
- from skimage import measure
360
- draw = ImageDraw.Draw(overlay)
361
- for i in range(masks_arr.shape[0]):
362
- conf = confidences[i] if i < len(confidences) else 1.0
363
- if conf < threshold:
364
- continue
365
- mask = masks_arr[i].astype(np.uint8)
366
- # resize mask for contour if needed
367
- if mask.shape[1] != W or mask.shape[0] != H:
368
- mask_pil = Image.fromarray((mask * 255).astype(np.uint8)).resize((W, H), resample=Image.NEAREST)
369
- mask = np.asarray(mask_pil).astype(np.uint8) // 255
370
- contours = measure.find_contours(mask, 0.5)
371
- for contour in contours:
372
- # contour is list of (row, col) -> convert to (x, y)
373
- pts = [(float(c[1]), float(c[0])) for c in contour]
374
- if len(pts) >= 3:
375
- # draw white outline
376
- draw.line(pts + [pts[0]], fill=(255, 255, 255, 255), width=2)
377
- except Exception:
378
- # ignore if skimage not available; outlines are optional
379
- pass
380
-
381
- annotated = Image.alpha_composite(base, overlay).convert("RGBA")
382
-
383
- # annotate best confidence text (top-left)
384
- if kept_confidences:
385
- best = max(kept_confidences)
386
- draw = ImageDraw.Draw(annotated)
387
- try:
388
- font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(14, W // 32))
389
- except Exception:
390
- font = ImageFont.load_default()
391
- text = f"Confidence: {best:.2f}"
392
- tw, th = draw.textsize(text, font=font)
393
- pad = 6
394
- rect = [6, 6, 6 + tw + pad, 6 + th + pad]
395
- draw.rectangle(rect, fill=(0, 0, 0, 180))
396
- draw.text((6 + pad // 2, 6 + pad // 2), text, font=font, fill=(255, 255, 255, 255))
397
-
398
- return annotated.convert("RGB"), kept_confidences
399
 
400
 
401
  @app.route("/", methods=["GET"])
402
  def index():
403
- # serve the static UI file if present
404
  index_path = os.path.join(app.static_folder or "static", "index.html")
405
  if os.path.exists(index_path):
406
  return send_from_directory(app.static_folder, "index.html")
@@ -421,11 +361,11 @@ def predict():
421
  except Exception as e:
422
  return jsonify({"error": f"Model initialization failed: {e}"}), 500
423
 
424
- # parse input
425
  img: Optional[Image.Image] = None
426
  conf_threshold = 0.25
427
 
428
- # If form file uploaded
429
  if "file" in request.files:
430
  file = request.files["file"]
431
  try:
@@ -434,7 +374,7 @@ def predict():
434
  return jsonify({"error": f"Invalid uploaded image: {e}"}), 400
435
  conf_threshold = float(request.form.get("conf", conf_threshold))
436
  else:
437
- # try JSON payload
438
  payload = request.get_json(silent=True)
439
  if not payload or "image" not in payload:
440
  return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
@@ -444,28 +384,69 @@ def predict():
444
  return jsonify({"error": f"Invalid image data: {e}"}), 400
445
  conf_threshold = float(payload.get("conf", conf_threshold))
446
 
447
- # run inference
 
 
 
 
 
 
 
 
 
448
  try:
449
- # set threshold=0.0 in model predict since we'll manually filter by conf_threshold
450
- detections = model.predict(img, threshold=0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  except Exception as e:
452
  traceback.print_exc()
453
  return jsonify({"error": f"Inference failed: {e}"}), 500
454
 
455
- # overlay masks and extract confidences > threshold
456
- annotated_pil, kept_conf = overlay_mask_on_image(img, detections, threshold=conf_threshold,
457
- mask_color=(255, 77, 166), alpha=0.45)
458
-
459
- data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
460
- return jsonify({"annotated": data_url, "confidences": kept_conf, "count": len(kept_conf)})
461
-
462
 
463
  if __name__ == "__main__":
464
- # Warm model in a background thread to avoid blocking the container start logs too long
465
  def warm():
466
  try:
 
467
  init_model()
 
468
  except Exception as e:
469
- print("Model warmup failed:", e)
 
 
470
  threading.Thread(target=warm, daemon=True).start()
471
- app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
 
 
 
 
 
 
 
193
  # print("Model init warning:", e)
194
  # app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
195
 
196
+
197
+
198
  import os
199
  import io
200
  import base64
201
  import threading
 
202
  import traceback
203
  from typing import Optional
204
 
205
+ from flask import Flask, request, jsonify, send_from_directory
206
+ from PIL import Image
207
  import numpy as np
208
  import requests
209
+ import torch
210
 
211
+ # Set environment variables for CPU-only operation
212
  os.environ.setdefault("MPLCONFIGDIR", "/tmp/.matplotlib")
213
  os.environ.setdefault("FONTCONFIG_PATH", "/tmp/.fontconfig")
 
214
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
215
+ os.environ.setdefault("OMP_NUM_THREADS", "4")
216
+ os.environ.setdefault("MKL_NUM_THREADS", "4")
217
+ os.environ.setdefault("OPENBLAS_NUM_THREADS", "4")
218
+
219
+ # Limit torch threads
220
+ torch.set_num_threads(4)
221
 
222
+ import supervision as sv
223
+ from rfdetr import RFDETRSegPreview
 
 
 
 
 
224
 
225
  app = Flask(__name__, static_folder="static", static_url_path="/")
226
 
 
233
 
234
 
235
  def download_file(url: str, dst: str, chunk_size: int = 8192):
236
+ """Download file if not exists"""
237
  if os.path.exists(dst) and os.path.getsize(dst) > 0:
238
+ print(f"[INFO] Checkpoint already exists at {dst}")
239
  return dst
240
  print(f"[INFO] Downloading weights from {url} -> {dst}")
241
+ r = requests.get(url, stream=True, timeout=120)
242
  r.raise_for_status()
243
  with open(dst, "wb") as fh:
244
  for chunk in r.iter_content(chunk_size=chunk_size):
 
249
 
250
 
251
  def init_model():
252
+ """Lazily initialize the RF-DETR model and cache it in global MODEL."""
 
 
 
253
  global MODEL
254
  with MODEL_LOCK:
255
  if MODEL is not None:
256
  return MODEL
257
  try:
258
+ # Ensure checkpoint present
259
  try:
260
  download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
261
  except Exception as e:
262
  print("[WARN] Failed to download checkpoint:", e)
263
+ if not os.path.exists(CHECKPOINT_PATH):
264
+ raise
265
 
266
  print("[INFO] Loading RF-DETR model (CPU mode)...")
267
+ MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH)
268
+
269
+ # Try to optimize for inference
270
  try:
271
  MODEL.optimize_for_inference()
272
  except Exception as e:
273
  print("[WARN] optimize_for_inference() skipped/failed:", e)
274
+
275
  print("[INFO] Model ready.")
276
  return MODEL
277
  except Exception:
 
280
 
281
 
282
  def decode_data_url(data_url: str) -> Image.Image:
283
+ """Decode data URL to PIL Image"""
 
 
284
  if data_url.startswith("data:"):
285
  _, b64 = data_url.split(",", 1)
286
  data = base64.b64decode(b64)
287
  else:
 
288
  try:
289
  data = base64.b64decode(data_url)
290
  except Exception:
 
293
 
294
 
295
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG") -> str:
    """Serialize a PIL image into a base64 data URL (default format: PNG)."""
    with io.BytesIO() as buf:
        pil_img.save(buf, format=fmt)
        payload = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/{fmt.lower()};base64,{payload}"
300
 
301
 
302
def annotate_segmentation(image: Image.Image, detections: sv.Detections) -> Image.Image:
    """
    Render segmentation overlays onto *image* using the supervision library.

    Draws per-instance colored masks, white polygon outlines, and a
    "Tulsi <confidence>" label placed at each mask's center of mass.
    Returns a new annotated image; the input image is left untouched.
    This matches the visualization from rfdetr_seg_infer.py script.
    """
    # Fixed 12-color palette, cycled across detected instances.
    colors = sv.ColorPalette.from_hex([
        "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
        "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00",
    ])

    # Pick a label size appropriate for the image resolution.
    scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)

    # One caption per detection, showing its confidence score.
    captions = [f"Tulsi {float(conf):.2f}" for conf in detections.confidence]

    # Apply mask fill, then outlines, then labels, onto a copy of the input.
    annotated = sv.MaskAnnotator(color=colors).annotate(image.copy(), detections)
    annotated = sv.PolygonAnnotator(color=sv.Color.WHITE).annotate(annotated, detections)
    annotated = sv.LabelAnnotator(
        color=colors,
        text_color=sv.Color.BLACK,
        text_scale=scale,
        text_position=sv.Position.CENTER_OF_MASS,
    ).annotate(annotated, detections, captions)

    return annotated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
 
341
  @app.route("/", methods=["GET"])
342
  def index():
343
+ """Serve the static UI"""
344
  index_path = os.path.join(app.static_folder or "static", "index.html")
345
  if os.path.exists(index_path):
346
  return send_from_directory(app.static_folder, "index.html")
 
361
  except Exception as e:
362
  return jsonify({"error": f"Model initialization failed: {e}"}), 500
363
 
364
+ # Parse input
365
  img: Optional[Image.Image] = None
366
  conf_threshold = 0.25
367
 
368
+ # Check if file uploaded
369
  if "file" in request.files:
370
  file = request.files["file"]
371
  try:
 
374
  return jsonify({"error": f"Invalid uploaded image: {e}"}), 400
375
  conf_threshold = float(request.form.get("conf", conf_threshold))
376
  else:
377
+ # Try JSON payload
378
  payload = request.get_json(silent=True)
379
  if not payload or "image" not in payload:
380
  return jsonify({"error": "No image provided. Upload 'file' or JSON with 'image' data-url."}), 400
 
384
  return jsonify({"error": f"Invalid image data: {e}"}), 400
385
  conf_threshold = float(payload.get("conf", conf_threshold))
386
 
387
+ # Optionally downscale large images to reduce memory usage
388
+ MAX_SIZE = 1024
389
+ if max(img.size) > MAX_SIZE:
390
+ w, h = img.size
391
+ scale = MAX_SIZE / float(max(w, h))
392
+ new_w, new_h = int(round(w * scale)), int(round(h * scale))
393
+ img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
394
+ print(f"[INFO] Resized image to {new_w}x{new_h}")
395
+
396
+ # Run inference with no_grad for memory efficiency
397
  try:
398
+ with torch.no_grad():
399
+ detections = model.predict(img, threshold=conf_threshold)
400
+
401
+ print(f"[INFO] Detected {len(detections)} objects")
402
+
403
+ # Check if detections exist
404
+ if len(detections) == 0:
405
+ print("[INFO] No detections above threshold")
406
+ # Return original image with message
407
+ data_url = encode_pil_to_dataurl(img, fmt="PNG")
408
+ return jsonify({
409
+ "annotated": data_url,
410
+ "confidences": [],
411
+ "count": 0
412
+ })
413
+
414
+ # Annotate image using supervision library
415
+ annotated_pil = annotate_segmentation(img, detections)
416
+
417
+ # Extract confidence scores
418
+ confidences = [float(conf) for conf in detections.confidence]
419
+
420
+ # Encode to data URL
421
+ data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
422
+
423
+ return jsonify({
424
+ "annotated": data_url,
425
+ "confidences": confidences,
426
+ "count": len(confidences)
427
+ })
428
+
429
  except Exception as e:
430
  traceback.print_exc()
431
  return jsonify({"error": f"Inference failed: {e}"}), 500
432
 
 
 
 
 
 
 
 
433
 
434
if __name__ == "__main__":
    # Load the model in a daemon thread so the HTTP server can start
    # accepting requests immediately instead of blocking on the download.
    def _warmup():
        try:
            print("[INFO] Starting model warmup...")
            init_model()
            print("[INFO] Model warmup complete")
        except Exception as e:
            print(f"[ERROR] Model warmup failed: {e}")
            traceback.print_exc()

    threading.Thread(target=_warmup, daemon=True).start()

    # Serve the Flask app; the PORT env var overrides the default 7860.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)