# app.py — Flask inference server for RF-DETR Tulsi-leaf segmentation.
# Originally published as a Hugging Face Space by Subh775 (commit db12d23).
import os
import io
import base64
import tempfile
import threading
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from flask import Flask, request, jsonify, send_from_directory
import requests

# Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# --- model import (ensure rfdetr package is available in requirements) ---
# Fail fast with a clear message if the dependency is missing, chaining the
# original import error for debugging.
try:
    from rfdetr import RFDETRSegPreview
except Exception as e:
    raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e

# Static files (index.html etc.) are served from ./static at the site root.
app = Flask(__name__, static_folder="static", static_url_path="/")

# HF checkpoint raw resolve URL (use the 'resolve/main' raw link)
CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-2/resolve/main/checkpoint_best_total.pth"
# /tmp is the writable scratch area on Hugging Face Spaces.
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")

# MODEL is lazily created by init_model(); MODEL_LOCK guards the one-time init
# against concurrent first requests.
MODEL_LOCK = threading.Lock()
MODEL = None
def download_file(url: str, dst: str):
if os.path.exists(dst):
return dst
print(f"[INFO] Downloading weights from {url} ...")
r = requests.get(url, stream=True, timeout=60)
r.raise_for_status()
with open(dst, "wb") as fh:
for chunk in r.iter_content(chunk_size=8192):
if chunk:
fh.write(chunk)
print("[INFO] Download complete.")
return dst
def init_model():
    """Lazily build the global RF-DETR model; thread-safe and idempotent."""
    global MODEL
    with MODEL_LOCK:
        if MODEL is not None:
            return MODEL
        # Best-effort checkpoint fetch: a download failure is logged and the
        # model falls back to its default weights.
        try:
            download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
        except Exception as e:
            print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.")
        print("[INFO] Loading RF-DETR model (CPU mode)...")
        weights = CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None
        MODEL = RFDETRSegPreview(pretrain_weights=weights)
        # Optimization is best-effort: it may be unavailable on CPU.
        try:
            MODEL.optimize_for_inference()
        except Exception:
            pass
        print("[INFO] Model ready.")
    return MODEL
@app.route("/")
def index():
    """Serve the single-page front-end from the static folder."""
    folder, page = "static", "index.html"
    return send_from_directory(folder, page)
def decode_data_url(data_url: str) -> Image.Image:
    """Decode a base64 payload (bare or ``data:`` URL form) into an RGB image."""
    # Strip the "data:<mime>;base64," header when present; otherwise the
    # whole string is assumed to be plain base64.
    if data_url.startswith("data:"):
        payload = data_url.split(",", 1)[1]
    else:
        payload = data_url
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw)).convert("RGB")
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"):
    """Serialize *pil_img* to a ``data:image/<fmt>;base64,...`` URL string."""
    buffer = io.BytesIO()
    pil_img.save(buffer, format=fmt)
    payload = base64.b64encode(buffer.getvalue()).decode("ascii")
    return "data:image/{};base64,{}".format(fmt.lower(), payload)
def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.25, mask_color=(255,77,166), alpha=0.45):
    """
    Overlay instance masks on an image and stamp the best confidence.

    Args:
        pil_img: input image (converted to RGBA internally).
        masks: list of HxW boolean arrays, or an ndarray shaped (N,H,W)
            (an (H,W,N) ndarray is transposed automatically).
        confidences: per-instance scores; a missing entry defaults to 1.0.
        threshold: instances with confidence below this are skipped.
        mask_color: RGB fill color of the overlay.
        alpha: overlay opacity in [0, 1].

    Returns:
        (annotated RGB PIL image, list of kept confidences).
    """
    base = pil_img.convert("RGBA")
    W, H = base.size
    # No detections: return the untouched image. Convert to RGB so both
    # return paths have the same mode (the original leaked RGBA here).
    if masks is None or (isinstance(masks, list) and len(masks) == 0):
        return base.convert("RGB"), []
    if isinstance(masks, list):
        # A list of HxW masks is already instance-first after stacking, so
        # the (H,W,N) transpose heuristic must not be applied to it.
        masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
    else:
        masks_arr = np.asarray(masks)
        # masks might be (H,W,N) -> transpose to (N,H,W)
        if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
            masks_arr = masks_arr.transpose(2, 0, 1)
    if masks_arr.size == 0:
        return base.convert("RGB"), []
    overlay = Image.new("RGBA", (W, H), (0, 0, 0, 0))
    kept_confidences = []
    for i in range(masks_arr.shape[0]):
        conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0
        if conf < threshold:
            continue
        mask = masks_arr[i].astype(np.uint8) * 255
        mask_img = Image.fromarray(mask).convert("L").resize((W, H), resample=Image.NEAREST)
        # Colored layer whose alpha channel is the (scaled) binary mask.
        color_layer = Image.new("RGBA", (W, H), mask_color + (0,))
        color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
        overlay = Image.alpha_composite(overlay, color_layer)
        kept_confidences.append(conf)
    annotated = Image.alpha_composite(base, overlay)
    # Stamp the highest kept confidence in the top-left corner.
    if kept_confidences:
        best = max(kept_confidences)
        draw = ImageDraw.Draw(annotated)
        try:
            font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, W // 30))
        except Exception:
            font = ImageFont.load_default()
        text = f"Confidence: {best:.2f}"
        # ImageDraw.textsize was removed in Pillow 10; measure with textbbox
        # and keep textsize as a fallback for very old Pillow versions.
        try:
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            tw, th = right - left, bottom - top
        except AttributeError:
            tw, th = draw.textsize(text, font=font)
        pad = 8
        draw.rectangle([6, 6, 6 + tw + pad, 6 + th + pad], fill=(0, 0, 0, 180))
        draw.text((6 + 4, 6 + 2), text, font=font, fill=(255, 255, 255, 255))
    return annotated.convert("RGB"), kept_confidences
@app.route("/predict", methods=["POST"])
def predict():
    """Run segmentation on a posted image.

    Expects JSON ``{"image": <data URL or base64>, "conf": <float, optional>}``
    and returns ``{"annotated": <data URL>, "confidences": [...], "count": N}``.
    Client errors return 400; inference failures return 500.
    """
    payload = request.get_json(force=True)
    if not payload or "image" not in payload:
        return jsonify({"error": "Missing image"}), 400
    # A malformed threshold is a client error, not a server crash (the
    # original let float() raise and answered 500).
    try:
        conf = float(payload.get("conf", 0.25))
    except (TypeError, ValueError):
        return jsonify({"error": "Invalid conf value"}), 400
    # Ensure the model is loaded (no-op after the first request).
    model = init_model()
    try:
        pil = decode_data_url(payload["image"])
    except Exception as e:
        return jsonify({"error": f"Invalid image: {e}"}), 400
    try:
        # Predict with threshold 0 and filter by `conf` ourselves below.
        detections = model.predict(pil, threshold=0.0)
    except Exception as e:
        return jsonify({"error": f"Inference failure: {e}"}), 500
    masks = getattr(detections, "masks", None)
    # Per-instance scores: prefer `confidence`, then `scores`. (The original
    # getattr(..., []) default meant the `scores` fallback was unreachable
    # when `confidence` was simply absent.)
    raw_scores = getattr(detections, "confidence", None)
    if raw_scores is None:
        raw_scores = getattr(detections, "scores", None)
    confidences = []
    if raw_scores is not None:
        try:
            confidences = [float(x) for x in raw_scores]
        except (TypeError, ValueError):
            confidences = []
    if not confidences:
        # Last resort: assume full confidence for every mask (matches the
        # overlay helper's default when scores are missing).
        n = masks.shape[0] if masks is not None and hasattr(masks, "shape") and len(masks.shape) else 0
        confidences = [1.0] * n
    mask_color = (255, 77, 166)  # pinkish overlay color
    annotated_pil, kept_conf = overlay_mask_on_image(pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45)
    data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
    return jsonify({
        "annotated": data_url,
        "confidences": kept_conf,
        "count": len(kept_conf)
    })
if __name__ == "__main__":
    # Warm the model at startup; a failure here is logged, not fatal —
    # init_model() will retry on the first request.
    try:
        init_model()
    except Exception as e:
        print("Model init warning:", e)
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)