Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import base64 | |
| import tempfile | |
| import threading | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| from flask import Flask, request, jsonify, send_from_directory | |
| import requests | |
# Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices.
# Hiding devices only takes effect if it happens before CUDA is initialised,
# which is presumably why this line precedes the rfdetr import below -- keep
# this ordering (NOTE(review): confirm rfdetr initialises CUDA lazily/at import).
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# --- model import (ensure rfdetr package is available in requirements) ---
try:
    from rfdetr import RFDETRSegPreview
except Exception as e:
    # Any failure while importing rfdetr (missing package or an error raised
    # during its import) is re-raised as a RuntimeError with a hint about
    # requirements; `from e` keeps the original exception in the traceback.
    raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e
# Flask app; files under ./static are served from the site root.
app = Flask(__name__, static_url_path="/", static_folder="static")

# Raw download link for the fine-tuned checkpoint ('resolve/main' serves the file bytes).
CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-2/resolve/main/checkpoint_best_total.pth"
# Local path where the downloaded checkpoint is cached.
CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth")

# Process-wide model singleton, created lazily by init_model() while holding MODEL_LOCK.
MODEL = None
MODEL_LOCK = threading.Lock()
def download_file(url: str, dst: str) -> str:
    """Download ``url`` to ``dst`` unless ``dst`` already exists.

    The response is streamed into a temporary file in the same directory as
    ``dst``. That file is moved onto ``dst`` only after the whole body has
    been written. An interrupted or failed transfer therefore never leaves a
    truncated ``dst`` behind; a later call would otherwise treat such a file
    as a complete download and skip fetching.

    Args:
        url: HTTP(S) URL to fetch.
        dst: Destination file path.

    Returns:
        ``dst``.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException, OSError: on network or filesystem errors.
    """
    if os.path.exists(dst):
        return dst
    print(f"[INFO] Downloading weights from {url} ...")
    # Stage next to ``dst`` (same filesystem) so the final os.replace is an
    # atomic rename: ``dst`` is either absent or complete, never partial.
    fd, tmp_path = tempfile.mkstemp(
        dir=os.path.dirname(os.path.abspath(dst)),
        prefix=os.path.basename(dst) + ".",
        suffix=".part",
    )
    try:
        # Both context managers guarantee cleanup: the file handle is closed
        # and the streamed HTTP connection is released even on failure.
        with os.fdopen(fd, "wb") as fh, requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=8192):
                if chunk:
                    fh.write(chunk)
        os.replace(tmp_path, dst)
    except BaseException:
        # Remove the partial temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print("[INFO] Download complete.")
    return dst
def init_model():
    """Return the shared RF-DETR segmentation model, creating it on first use.

    Creation happens while holding MODEL_LOCK, so concurrent first callers
    build the model only once; every later call returns the cached MODEL.

    On the first call:
      1. Try to download the checkpoint to CHECKPOINT_PATH. A failure is
         logged and is not fatal.
      2. Build RFDETRSegPreview with that checkpoint if the file exists,
         otherwise with ``pretrain_weights=None``.
      3. Best-effort ``optimize_for_inference()``. A failure is logged and the
         unoptimised model is used.

    Returns:
        The initialised model instance.

    Raises:
        Whatever RFDETRSegPreview raises if construction fails. MODEL then
        stays None, so a later call retries initialisation.
    """
    global MODEL
    with MODEL_LOCK:
        if MODEL is None:
            # Ensure model checkpoint
            try:
                download_file(CHECKPOINT_URL, CHECKPOINT_PATH)
            except Exception as e:
                print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.")
                # Continue anyway; what pretrain_weights=None loads is up to
                # rfdetr (presumably its default weights -- TODO confirm).
            print("[INFO] Loading RF-DETR model (CPU mode)...")
            weights = CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None
            model = RFDETRSegPreview(pretrain_weights=weights)
            try:
                model.optimize_for_inference()
            except Exception as e:
                # Optimisation may be unsupported on CPU or in this rfdetr
                # build; the unoptimised model still works, so report and go on.
                print(f"[WARN] optimize_for_inference() failed; using unoptimized model: {e}")
            # Publish the model only once it is fully prepared.
            MODEL = model
            print("[INFO] Model ready.")
    return MODEL
def index():
    """Serve the front-end page ``index.html`` from the ``static`` directory."""
    # NOTE(review): no @app.route decorator is visible on this function in
    # this copy of the file; unless it is registered elsewhere, Flask never
    # dispatches to it.
    folder, page = "static", "index.html"
    return send_from_directory(folder, page)
def decode_data_url(data_url: str) -> "Image.Image":
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts either an RFC 2397 data URL (``data:<mime>;base64,<payload>``) or
    a bare base64 string.

    Args:
        data_url: The data URL or plain base64 payload.

    Returns:
        The decoded image converted to RGB.

    Raises:
        ValueError: if a ``data:`` URL has no ``,`` separator or is not
            base64-encoded, or if the payload is not valid base64
            (``binascii.Error`` is a ValueError subclass).
        PIL.UnidentifiedImageError / OSError: if the bytes are not a readable image.
    """
    payload = data_url
    if data_url.startswith("data:"):
        header, sep, payload = data_url.partition(",")
        if not sep:
            raise ValueError("malformed data URL: missing ',' separator")
        # Non-base64 data URLs carry percent-encoded text; decoding them as
        # base64 would only yield garbage bytes, so reject them explicitly.
        if not header.lower().endswith(";base64"):
            raise ValueError("only base64-encoded data URLs are supported")
    data = base64.b64decode(payload)
    return Image.open(io.BytesIO(data)).convert("RGB")
def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"):
    """Serialise ``pil_img`` and wrap it as a base64 ``data:`` URL.

    Args:
        pil_img: Image to encode.
        fmt: Pillow format name passed to ``save``; its lower-cased form is
            used as the MIME subtype (e.g. ``"PNG"`` -> ``image/png``).

    Returns:
        A string of the form ``data:image/<fmt>;base64,<payload>``.
    """
    with io.BytesIO() as stream:
        pil_img.save(stream, format=fmt)
        encoded = base64.b64encode(stream.getvalue()).decode("ascii")
    return f"data:image/{fmt.lower()};base64,{encoded}"
def _normalize_masks(masks, width, height):
    """Return ``masks`` as an array of shape (N, H, W); N may be 0."""
    if isinstance(masks, (list, tuple)):
        if len(masks) == 0:
            # np.stack rejects an empty sequence; report "no masks" instead.
            return np.zeros((0, height, width), dtype=bool)
        arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
    else:
        arr = np.asarray(masks)
    if arr.ndim == 2:
        # A single HxW mask: treat it as one instance.
        arr = arr[np.newaxis, ...]
    elif arr.ndim == 3 and arr.shape[0] == height and arr.shape[1] == width:
        # Channels-last (H, W, N) -> (N, H, W). NOTE: ambiguous when N == H == W.
        arr = arr.transpose(2, 0, 1)
    return arr


def _draw_confidence_label(img, text):
    """Draw ``text`` over a translucent black box at the top-left of RGB ``img`` (in place)."""
    # "RGBA" draw mode on an RGB image blends fill colours using their alpha,
    # so the box's alpha of 180 actually yields a translucent box.
    draw = ImageDraw.Draw(img, "RGBA")
    try:
        # Prefer a TrueType font scaled to the image width.
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, img.size[0] // 30))
    except Exception:
        # Font missing or FreeType unavailable: use Pillow's built-in font.
        font = ImageFont.load_default()
    try:
        # Bounding box of the text drawn at the origin; (right, bottom) also
        # covers any top/left bearing offset. draw.textsize was removed in Pillow 10.
        _, _, text_w, text_h = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # Older Pillow: textbbox missing, or limited to TrueType fonts.
        text_w, text_h = draw.textsize(text, font=font)
    pad = 8
    draw.rectangle([6, 6, 6 + text_w + pad, 6 + text_h + pad], fill=(0, 0, 0, 180))
    draw.text((6 + 4, 6 + 2), text, font=font, fill=(255, 255, 255, 255))


def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.25, mask_color=(255,77,166), alpha=0.45):
    """Tint every sufficiently confident mask onto ``pil_img``.

    Args:
        pil_img: Source image (any mode; converted to RGBA for compositing).
        masks: None, a list of HxW array-likes, or an array shaped (N, H, W),
            (H, W, N) or (H, W) for a single mask. Each mask is scaled by 255
            after a uint8 cast (presumably boolean/0-1 masks -- TODO confirm
            against the model output). It is then resized to the image size
            with nearest-neighbour sampling.
        confidences: Per-mask scores aligned with ``masks``. ``None``, or a
            missing entry, counts as 1.0.
        threshold: Masks whose confidence is below this are skipped.
        mask_color: RGB tint colour.
        alpha: Opacity multiplier applied to mask pixel values (1.0 = opaque).

    Returns:
        ``(annotated_rgb_image, kept_confidences)``. When at least one mask is
        kept, the highest kept confidence is written in the top-left corner.
    """
    base = pil_img.convert("RGBA")
    W, H = base.size
    if masks is None:
        # Same RGB output mode as the normal path.
        return base.convert("RGB"), []
    masks_arr = _normalize_masks(masks, W, H)

    # Fully transparent layer of the tint colour; each mask sets its alpha.
    tint = tuple(mask_color) + (0,)
    overlay = Image.new("RGBA", (W, H), (0, 0, 0, 0))
    kept_confidences = []
    for i, mask in enumerate(masks_arr):
        conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0
        if conf < threshold:
            continue
        mask_img = Image.fromarray(mask.astype(np.uint8) * 255).convert("L").resize((W, H), resample=Image.NEAREST)
        color_layer = Image.new("RGBA", (W, H), tint)
        color_layer.putalpha(mask_img.point(lambda p: int(p * alpha)))
        # Composite one mask at a time, so overlapping masks stack their tint.
        overlay = Image.alpha_composite(overlay, color_layer)
        kept_confidences.append(conf)

    annotated = Image.alpha_composite(base, overlay).convert("RGB")
    if kept_confidences:
        _draw_confidence_label(annotated, f"Confidence: {max(kept_confidences):.2f}")
    return annotated, kept_confidences
def _extract_confidences(detections):
    """Return per-instance scores from ``detections`` as a list of floats.

    Tries the ``confidence`` attribute, then ``scores``. The first one that is
    present, not None, and convertible element-wise to float wins. Returns []
    when neither works; overlay_mask_on_image then treats every mask as
    confidence 1.0.
    """
    for attr in ("confidence", "scores"):
        values = getattr(detections, attr, None)
        if values is None:
            continue
        try:
            return [float(x) for x in values]
        except (TypeError, ValueError):
            continue
    return []


def predict():
    """Segment a posted image and return it annotated with the detected masks.

    Expects a JSON body ``{"image": <data URL or bare base64>, "conf": <number, optional>}``.
    The Content-Type header is not checked. ``conf`` (default 0.25) is the
    minimum score for a mask to be drawn and reported.

    Returns:
        200 with ``{"annotated": <PNG data URL>, "confidences": [...], "count": int}``;
        400 for a missing or invalid body, ``conf``, or image;
        500 if the model cannot be initialised or inference fails.
    """
    # NOTE(review): no @app.route decorator is visible on this function in
    # this copy of the file; unless it is registered elsewhere, Flask never
    # dispatches to it.
    # silent=True: malformed JSON yields None and gets the JSON 400 below
    # instead of Flask's default HTML error page.
    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict) or "image" not in payload:
        return jsonify({"error": "Missing image"}), 400
    try:
        conf = float(payload.get("conf", 0.25))
    except (TypeError, ValueError):
        return jsonify({"error": "Invalid conf: expected a number"}), 400

    # Decode before touching the model, so bad requests fail fast without
    # triggering the (potentially slow) first-time model initialisation.
    try:
        pil = decode_data_url(payload["image"])
    except Exception as e:
        return jsonify({"error": f"Invalid image: {e}"}), 400

    try:
        model = init_model()
    except Exception as e:
        return jsonify({"error": f"Model initialization failure: {e}"}), 500

    try:
        # threshold=0.0 keeps every detection; filtering by ``conf`` happens in
        # overlay_mask_on_image so it is applied in exactly one place.
        detections = model.predict(pil, threshold=0.0)
    except Exception as e:
        return jsonify({"error": f"Inference failure: {e}"}), 500

    masks = getattr(detections, "masks", None)
    confidences = _extract_confidences(detections)
    mask_color = (255, 77, 166)  # pinkish
    annotated_pil, kept_conf = overlay_mask_on_image(
        pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45
    )
    return jsonify({
        "annotated": encode_pil_to_dataurl(annotated_pil, fmt="PNG"),
        "confidences": kept_conf,
        "count": len(kept_conf),
    })
if __name__ == "__main__":
    def _warm_up():
        """Initialise the model, logging (not raising) any failure."""
        try:
            init_model()
        except Exception as e:
            print("Model init warning:", e)

    # Warm the model up in a daemon thread so the HTTP server starts listening
    # right away; the checkpoint download can be slow. Requests arriving during
    # warm-up block on MODEL_LOCK inside init_model() until the model is ready.
    # If warm-up failed and MODEL is still None, they retry initialisation.
    threading.Thread(target=_warm_up, name="model-warmup", daemon=True).start()
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)