# app.py — RF-DETR Tulsi segmentation API (Hugging Face Space)
# (web-page scrape residue removed: author byline, "Update app.py",
#  commit hash "ecee7e2 verified", "raw / history blame", "10.9 kB")
# import os
# import io
# import base64
# import tempfile
# import threading
# from PIL import Image, ImageDraw, ImageFont
# import numpy as np
# from flask import Flask, request, jsonify, send_from_directory
# import requests
# # Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
# # --- model import (ensure rfdetr package is available in requirements) ---
# try:
# from rfdetr import RFDETRSegPreview
# except Exception as e:
# raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e
# app = Flask(__name__, static_folder="static", static_url_path="/")
# # HF checkpoint raw resolve URL (use the 'resolve/main' raw link)
# CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
# CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")
# MODEL_LOCK = threading.Lock()
# MODEL = None
# def download_file(url: str, dst: str):
# if os.path.exists(dst):
# return dst
# print(f"[INFO] Downloading weights from {url} ...")
# r = requests.get(url, stream=True, timeout=60)
# r.raise_for_status()
# with open(dst, "wb") as fh:
# for chunk in r.iter_content(chunk_size=8192):
# if chunk:
# fh.write(chunk)
# print("[INFO] Download complete.")
# return dst
# def init_model():
# global MODEL
# with MODEL_LOCK:
# if MODEL is None:
# # Ensure model checkpoint
# try:
# download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
# except Exception as e:
# print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.")
# # continue; model may fallback to default weights
# print("[INFO] Loading RF-DETR model (CPU mode)...")
# MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None)
# try:
# MODEL.optimize_for_inference()
# except Exception:
# # optimization may fail on CPU or if not implemented; ignore
# pass
# print("[INFO] Model ready.")
# return MODEL
# @app.route("/")
# def index():
# return send_from_directory("static", "index.html")
# def decode_data_url(data_url: str) -> Image.Image:
# if data_url.startswith("data:"):
# header, b64 = data_url.split(",", 1)
# data = base64.b64decode(b64)
# return Image.open(io.BytesIO(data)).convert("RGB")
# else:
# # assume plain base64 or path
# data = base64.b64decode(data_url)
# return Image.open(io.BytesIO(data)).convert("RGB")
# def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"):
# buf = io.BytesIO()
# pil_img.save(buf, format=fmt)
# b = base64.b64encode(buf.getvalue()).decode("ascii")
# return f"data:image/{fmt.lower()};base64,{b}"
# def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.01, mask_color=(255,77,166), alpha=0.45):
# """
# masks: either list of HxW bool arrays or numpy array (N,H,W)
# confidences: list of floats
# Returns annotated PIL image and list of kept confidences and count.
# """
# base = pil_img.convert("RGBA")
# W, H = base.size
# # Normalize masks to N,H,W
# if masks is None:
# return base, []
# if isinstance(masks, list):
# masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
# else:
# masks_arr = np.asarray(masks)
# # masks might be (H,W,N) -> transpose
# if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W:
# masks_arr = masks_arr.transpose(2, 0, 1)
# # create overlay
# overlay = Image.new("RGBA", (W, H), (0,0,0,0))
# draw = ImageDraw.Draw(overlay)
# kept_confidences = []
# for i in range(masks_arr.shape[0]):
# conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0
# if conf < threshold:
# continue
# mask = masks_arr[i].astype(np.uint8) * 255
# mask_img = Image.fromarray(mask).convert("L").resize((W, H), resample=Image.NEAREST)
# # create colored mask image
# color_layer = Image.new("RGBA", (W,H), mask_color + (0,))
# # put alpha using mask
# color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
# overlay = Image.alpha_composite(overlay, color_layer)
# kept_confidences.append(conf)
# # composite
# annotated = Image.alpha_composite(base, overlay)
# # add confidence text (show highest kept confidence)
# if len(kept_confidences) > 0:
# best = max(kept_confidences)
# draw = ImageDraw.Draw(annotated)
# try:
# # Try to use a builtin font
# font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, W//30))
# except Exception:
# font = ImageFont.load_default()
# text = f"Confidence: {best:.2f}"
# # draw background box for text
# tw, th = draw.textsize(text, font=font)
# pad = 8
# draw.rectangle([6,6, 6+tw+pad, 6+th+pad], fill=(0,0,0,180))
# draw.text((6+4,6+2), text, font=font, fill=(255,255,255,255))
# return annotated.convert("RGB"), kept_confidences
# @app.route("/predict", methods=["POST"])
# def predict():
# payload = request.get_json(force=True)
# if not payload or "image" not in payload:
# return jsonify({"error": "Missing image"}), 400
# conf = float(payload.get("conf", 0.25))
# # ensure model ready
# model = init_model()
# # decode image
# try:
# pil = decode_data_url(payload["image"])
# except Exception as e:
# return jsonify({"error": f"Invalid image: {e}"}), 400
# # perform prediction (model.predict expects PIL image)
# try:
# detections = model.predict(pil, threshold=0.0) # we filter using conf manually
# except Exception as e:
# return jsonify({"error": f"Inference failure: {e}"}), 500
# # extract masks and confidences
# masks = getattr(detections, "masks", None)
# confidences = []
# # attempt to read per-instance confidence
# try:
# confidences = [float(x) for x in getattr(detections, "confidence", [])]
# except Exception:
# # fallback: attempt attribute 'scores' or 'scores_' or generate ones
# confidences = []
# try:
# confidences = [float(x) for x in getattr(detections, "scores", [])]
# except Exception:
# confidences = [1.0] * (masks.shape[0] if masks is not None and hasattr(masks, "shape") and masks.shape[0] else 0)
# # overlay mask with pink-red color
# mask_color = (255, 77, 166) # pinkish
# annotated_pil, kept_conf = overlay_mask_on_image(pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45)
# data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG")
# return jsonify({
# "annotated": data_url,
# "confidences": kept_conf,
# "count": len(kept_conf)
# })
# if __name__ == "__main__":
# # warm up model on startup (non-blocking)
# try:
# init_model()
# except Exception as e:
# print("Model init warning:", e)
# app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)
import os
import io
import numpy as np
from PIL import Image
import requests
import supervision as sv
from flask import Flask, request, jsonify, send_file
from rfdetr import RFDETRSegPreview
app = Flask(__name__)
# ---- CONFIG ----
# Direct-download ("resolve") URL for the fine-tuned RF-DETR segmentation
# checkpoint on the Hugging Face Hub.
WEIGHTS_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
# /tmp is writable in the hosting container; weights are cached there per boot.
WEIGHTS_PATH = "/tmp/checkpoint_best_total.pth"
# ---- HELPERS ----
def download_file(url: str, dst: str) -> str:
    """Download model weights to *dst*, skipping the download if cached.

    Streams the response in 8 KiB chunks into a temporary ``.part`` file and
    atomically renames it into place, so an interrupted download never leaves
    a truncated file that the existence check would later mistake for a valid
    cache hit.

    Args:
        url: Direct (resolve) URL of the checkpoint file.
        dst: Destination path on local disk.

    Returns:
        The destination path *dst*.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the connection or a read stalls past 60 s.
    """
    if os.path.exists(dst):
        print(f"[INFO] Weights already exist at {dst}")
        return dst
    print(f"[INFO] Downloading weights from {url} ...")
    # Timeout guards against a hung connection stalling startup forever.
    r = requests.get(url, stream=True, timeout=60)
    r.raise_for_status()
    tmp = dst + ".part"
    with open(tmp, "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            if chunk:  # skip keep-alive heartbeat chunks
                f.write(chunk)
    # Atomic rename: the cache file is either absent or complete.
    os.replace(tmp, dst)
    print("[INFO] Download complete.")
    return dst
def annotate_segmentation(image: Image.Image, detections: sv.Detections):
    """Overlay colored instance masks, white outlines, and confidence labels.

    Args:
        image: Input RGB PIL image (left unmodified; a copy is annotated).
        detections: ``supervision`` detections carrying masks and, usually,
            per-instance confidence scores.

    Returns:
        A new PIL image with masks, polygon outlines, and per-instance
        confidence labels drawn on top.
    """
    palette = sv.ColorPalette.from_hex([
        "#ff9b00", "#ff8080", "#ff66b2", "#b266ff",
        "#9999ff", "#3399ff", "#33ff99", "#99ff00"
    ])
    # Scale label text relative to the image resolution.
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
    mask_annotator = sv.MaskAnnotator(color=palette)
    polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)
    label_annotator = sv.LabelAnnotator(
        color=palette,
        text_color=sv.Color.BLACK,
        text_scale=text_scale,
        text_position=sv.Position.CENTER_OF_MASS
    )
    # Only show confidence (no class id). `Detections.confidence` is Optional
    # in supervision's API, so guard against None to avoid a TypeError.
    if detections.confidence is not None:
        labels = [f"{conf:.2f}" for conf in detections.confidence]
    else:
        labels = ["" for _ in range(len(detections))]
    annotated = image.copy()
    annotated = mask_annotator.annotate(annotated, detections)
    annotated = polygon_annotator.annotate(annotated, detections)
    annotated = label_annotator.annotate(annotated, detections, labels)
    return annotated
# ---- MODEL INITIALIZATION ----
# Load once at import time so every request reuses the same model instance
# (also warms the model before the first HTTP request arrives).
print("[INFO] Loading RF-DETR model (CPU mode)...")
download_file(WEIGHTS_URL, WEIGHTS_PATH)
model = RFDETRSegPreview(pretrain_weights=WEIGHTS_PATH)
try:
    # Optional inference-time optimization; best-effort — it may be
    # unsupported on CPU-only builds, in which case we proceed unoptimized.
    model.optimize_for_inference()
except Exception as e:
    print(f"[WARN] optimize_for_inference() skipped: {e}")
print("[INFO] Model ready.")
# ---- ROUTES ----
@app.route("/")
def home():
return jsonify({"message": "RF-DETR Segmentation API is running."})
@app.route("/predict", methods=["POST"])
def predict():
"""Accepts an image file and returns annotated segmentation overlay."""
if "file" not in request.files:
return jsonify({"error": "No file uploaded"}), 400
file = request.files["file"]
image = Image.open(file.stream).convert("RGB")
print(f"[INFO] Image received for inference: {file.filename}")
detections = model.predict(image, threshold=0.3)
print(f"[INFO] Detections found: {len(getattr(detections, 'boxes', []))}")
annotated = annotate_segmentation(image, detections)
buf = io.BytesIO()
annotated.save(buf, format="PNG")
buf.seek(0)
return send_file(buf, mimetype="image/png")
# if __name__ == "__main__":
# app.run(host="0.0.0.0", port=7860)
if __name__ == "__main__":
# warm up model on startup (non-blocking)
try:
init_model()
except Exception as e:
print("Model init warning:", e)
app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)