Spaces:
Sleeping
Sleeping
| # import os | |
| # import io | |
| # import base64 | |
| # import tempfile | |
| # import threading | |
| # from PIL import Image, ImageDraw, ImageFont | |
| # import numpy as np | |
| # from flask import Flask, request, jsonify, send_from_directory | |
| # import requests | |
| # # Force CPU-only (prevents accidental GPU usage); works by hiding CUDA devices | |
| # os.environ["CUDA_VISIBLE_DEVICES"] = "" | |
| # # --- model import (ensure rfdetr package is available in requirements) --- | |
| # try: | |
| # from rfdetr import RFDETRSegPreview | |
| # except Exception as e: | |
| # raise RuntimeError("rfdetr package import failed. Make sure `rfdetr` is in requirements.") from e | |
| # app = Flask(__name__, static_folder="static", static_url_path="/") | |
| # # HF checkpoint raw resolve URL (use the 'resolve/main' raw link) | |
| # CHECKPOINT_URL = "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth" | |
| # CHECKPOINT_PATH = os.path.join("/tmp", "checkpoint_best_total.pth") | |
| # MODEL_LOCK = threading.Lock() | |
| # MODEL = None | |
| # def download_file(url: str, dst: str): | |
| # if os.path.exists(dst): | |
| # return dst | |
| # print(f"[INFO] Downloading weights from {url} ...") | |
| # r = requests.get(url, stream=True, timeout=60) | |
| # r.raise_for_status() | |
| # with open(dst, "wb") as fh: | |
| # for chunk in r.iter_content(chunk_size=8192): | |
| # if chunk: | |
| # fh.write(chunk) | |
| # print("[INFO] Download complete.") | |
| # return dst | |
| # def init_model(): | |
| # global MODEL | |
| # with MODEL_LOCK: | |
| # if MODEL is None: | |
| # # Ensure model checkpoint | |
| # try: | |
| # download_file(CHECKPOINT_URL, CHECKPOINT_PATH) | |
| # except Exception as e: | |
| # print(f"[WARN] Failed to download checkpoint: {e}. Attempting to init model without weights.") | |
| # # continue; model may fallback to default weights | |
| # print("[INFO] Loading RF-DETR model (CPU mode)...") | |
| # MODEL = RFDETRSegPreview(pretrain_weights=CHECKPOINT_PATH if os.path.exists(CHECKPOINT_PATH) else None) | |
| # try: | |
| # MODEL.optimize_for_inference() | |
| # except Exception: | |
| # # optimization may fail on CPU or if not implemented; ignore | |
| # pass | |
| # print("[INFO] Model ready.") | |
| # return MODEL | |
| # @app.route("/") | |
| # def index(): | |
| # return send_from_directory("static", "index.html") | |
| # def decode_data_url(data_url: str) -> Image.Image: | |
| # if data_url.startswith("data:"): | |
| # header, b64 = data_url.split(",", 1) | |
| # data = base64.b64decode(b64) | |
| # return Image.open(io.BytesIO(data)).convert("RGB") | |
| # else: | |
| # # assume plain base64 or path | |
| # data = base64.b64decode(data_url) | |
| # return Image.open(io.BytesIO(data)).convert("RGB") | |
| # def encode_pil_to_dataurl(pil_img: Image.Image, fmt="PNG"): | |
| # buf = io.BytesIO() | |
| # pil_img.save(buf, format=fmt) | |
| # b = base64.b64encode(buf.getvalue()).decode("ascii") | |
| # return f"data:image/{fmt.lower()};base64,{b}" | |
| # def overlay_mask_on_image(pil_img: Image.Image, masks, confidences, threshold=0.01, mask_color=(255,77,166), alpha=0.45): | |
| # """ | |
| # masks: either list of HxW bool arrays or numpy array (N,H,W) | |
| # confidences: list of floats | |
| # Returns annotated PIL image and list of kept confidences and count. | |
| # """ | |
| # base = pil_img.convert("RGBA") | |
| # W, H = base.size | |
| # # Normalize masks to N,H,W | |
| # if masks is None: | |
| # return base, [] | |
| # if isinstance(masks, list): | |
| # masks_arr = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0) | |
| # else: | |
| # masks_arr = np.asarray(masks) | |
| # # masks might be (H,W,N) -> transpose | |
| # if masks_arr.ndim == 3 and masks_arr.shape[0] == H and masks_arr.shape[1] == W: | |
| # masks_arr = masks_arr.transpose(2, 0, 1) | |
| # # create overlay | |
| # overlay = Image.new("RGBA", (W, H), (0,0,0,0)) | |
| # draw = ImageDraw.Draw(overlay) | |
| # kept_confidences = [] | |
| # for i in range(masks_arr.shape[0]): | |
| # conf = float(confidences[i]) if confidences is not None and i < len(confidences) else 1.0 | |
| # if conf < threshold: | |
| # continue | |
| # mask = masks_arr[i].astype(np.uint8) * 255 | |
| # mask_img = Image.fromarray(mask).convert("L").resize((W, H), resample=Image.NEAREST) | |
| # # create colored mask image | |
| # color_layer = Image.new("RGBA", (W,H), mask_color + (0,)) | |
| # # put alpha using mask | |
| # color_layer.putalpha(mask_img.point(lambda p: int(p * alpha))) | |
| # overlay = Image.alpha_composite(overlay, color_layer) | |
| # kept_confidences.append(conf) | |
| # # composite | |
| # annotated = Image.alpha_composite(base, overlay) | |
| # # add confidence text (show highest kept confidence) | |
| # if len(kept_confidences) > 0: | |
| # best = max(kept_confidences) | |
| # draw = ImageDraw.Draw(annotated) | |
| # try: | |
| # # Try to use a builtin font | |
| # font = ImageFont.truetype("DejaVuSans-Bold.ttf", size=max(16, W//30)) | |
| # except Exception: | |
| # font = ImageFont.load_default() | |
| # text = f"Confidence: {best:.2f}" | |
| # # draw background box for text | |
| # tw, th = draw.textsize(text, font=font) | |
| # pad = 8 | |
| # draw.rectangle([6,6, 6+tw+pad, 6+th+pad], fill=(0,0,0,180)) | |
| # draw.text((6+4,6+2), text, font=font, fill=(255,255,255,255)) | |
| # return annotated.convert("RGB"), kept_confidences | |
| # @app.route("/predict", methods=["POST"]) | |
| # def predict(): | |
| # payload = request.get_json(force=True) | |
| # if not payload or "image" not in payload: | |
| # return jsonify({"error": "Missing image"}), 400 | |
| # conf = float(payload.get("conf", 0.25)) | |
| # # ensure model ready | |
| # model = init_model() | |
| # # decode image | |
| # try: | |
| # pil = decode_data_url(payload["image"]) | |
| # except Exception as e: | |
| # return jsonify({"error": f"Invalid image: {e}"}), 400 | |
| # # perform prediction (model.predict expects PIL image) | |
| # try: | |
| # detections = model.predict(pil, threshold=0.0) # we filter using conf manually | |
| # except Exception as e: | |
| # return jsonify({"error": f"Inference failure: {e}"}), 500 | |
| # # extract masks and confidences | |
| # masks = getattr(detections, "masks", None) | |
| # confidences = [] | |
| # # attempt to read per-instance confidence | |
| # try: | |
| # confidences = [float(x) for x in getattr(detections, "confidence", [])] | |
| # except Exception: | |
| # # fallback: attempt attribute 'scores' or 'scores_' or generate ones | |
| # confidences = [] | |
| # try: | |
| # confidences = [float(x) for x in getattr(detections, "scores", [])] | |
| # except Exception: | |
| # confidences = [1.0] * (masks.shape[0] if masks is not None and hasattr(masks, "shape") and masks.shape[0] else 0) | |
| # # overlay mask with pink-red color | |
| # mask_color = (255, 77, 166) # pinkish | |
| # annotated_pil, kept_conf = overlay_mask_on_image(pil, masks, confidences, threshold=conf, mask_color=mask_color, alpha=0.45) | |
| # data_url = encode_pil_to_dataurl(annotated_pil, fmt="PNG") | |
| # return jsonify({ | |
| # "annotated": data_url, | |
| # "confidences": kept_conf, | |
| # "count": len(kept_conf) | |
| # }) | |
| # if __name__ == "__main__": | |
| # # warm up model on startup (non-blocking) | |
| # try: | |
| # init_model() | |
| # except Exception as e: | |
| # print("Model init warning:", e) | |
| # app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False) | |
| import os | |
| import io | |
| import numpy as np | |
| from PIL import Image | |
| import requests | |
| import supervision as sv | |
| from flask import Flask, request, jsonify, send_file | |
| from rfdetr import RFDETRSegPreview | |
# WSGI application object; every route below is registered on it.
app = Flask(import_name=__name__)
# ---- CONFIG ----
# Raw-download ("resolve/main") URL of the fine-tuned RF-DETR segmentation checkpoint.
WEIGHTS_URL = (
    "https://huggingface.co/Subh775/Segment-Tulsi-TFs-3/resolve/main/checkpoint_best_total.pth"
)
# Local cache location; download_file() skips the network fetch when this file exists.
WEIGHTS_PATH = "/tmp/checkpoint_best_total.pth"
| # ---- HELPERS ---- | |
def download_file(url: str, dst: str, timeout: float = 60) -> str:
    """Download ``url`` to ``dst`` unless ``dst`` already exists.

    The body is streamed into a sibling ``<dst>.part`` file that is moved onto
    ``dst`` only after every chunk has been written. An interrupted or failed
    download therefore never leaves a truncated file at ``dst``, which later
    calls would otherwise mistake for a valid cached copy.

    Args:
        url: Address to fetch.
        dst: Destination path. If it already exists, nothing is downloaded.
        timeout: Seconds passed to ``requests.get`` as its connect/read timeout,
            so a stalled server cannot hang the caller indefinitely.

    Returns:
        ``dst``.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection failures or timeouts.
        OSError: If the temporary or destination file cannot be written.
    """
    if os.path.exists(dst):
        print(f"[INFO] Weights already exist at {dst}")
        return dst

    print(f"[INFO] Downloading weights from {url} ...")
    tmp_path = f"{dst}.part"
    try:
        # ``with`` releases the streamed connection even if writing fails midway.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        # Atomic on the same filesystem: dst is either absent or complete.
        os.replace(tmp_path, dst)
    except BaseException:
        # Remove the partial file (also on Ctrl-C), then propagate the error.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise
    print("[INFO] Download complete.")
    return dst
def annotate_segmentation(image: Image.Image, detections: sv.Detections) -> Image.Image:
    """Overlay instance masks, mask outlines and confidence labels on a copy of ``image``.

    Args:
        image: Source image. It is copied first, so the caller's image is left as-is.
        detections: Segmentation output. Masks are filled with a cycling palette
            and outlined in white. When per-instance confidences are present, each
            instance is labelled with its score (two decimals, no class id) at the
            mask's centre of mass.

    Returns:
        The annotated copy of ``image``.
    """
    palette = sv.ColorPalette.from_hex([
        "#ff9b00", "#ff8080", "#ff66b2", "#b266ff",
        "#9999ff", "#3399ff", "#33ff99", "#99ff00",
    ])
    mask_annotator = sv.MaskAnnotator(color=palette)
    polygon_annotator = sv.PolygonAnnotator(color=sv.Color.WHITE)

    annotated = image.copy()
    annotated = mask_annotator.annotate(annotated, detections)
    annotated = polygon_annotator.annotate(annotated, detections)

    # Labels show the confidence only. When the detections carry no scores
    # (``confidence is None``) there is nothing to print, and iterating over
    # None would raise TypeError, so the label pass is skipped.
    if detections.confidence is not None:
        label_annotator = sv.LabelAnnotator(
            color=palette,
            text_color=sv.Color.BLACK,
            text_scale=sv.calculate_optimal_text_scale(resolution_wh=image.size),
            text_position=sv.Position.CENTER_OF_MASS,
        )
        labels = [f"{conf:.2f}" for conf in detections.confidence]
        annotated = label_annotator.annotate(annotated, detections, labels)
    return annotated
# ---- MODEL INITIALIZATION ----
def _load_model():
    """Fetch the checkpoint if needed and build the RF-DETR segmentation model.

    Runs once at import time. A failed download propagates and aborts startup;
    a failed ``optimize_for_inference()`` is only logged.
    """
    # NOTE(review): the log says "CPU mode", but nothing here forces CPU
    # execution; confirm how the device is selected.
    print("[INFO] Loading RF-DETR model (CPU mode)...")
    download_file(WEIGHTS_URL, WEIGHTS_PATH)
    seg_model = RFDETRSegPreview(pretrain_weights=WEIGHTS_PATH)
    try:
        seg_model.optimize_for_inference()
    except Exception as err:
        # Best effort: the un-optimized model is still usable for predict().
        print(f"[WARN] optimize_for_inference() skipped: {err}")
    print("[INFO] Model ready.")
    return seg_model


model = _load_model()
| # ---- ROUTES ---- | |
@app.route("/")
def home():
    """Health-check endpoint: report that the API is up.

    The route decorator is required; without it Flask never registers this
    view and ``/`` answers 404.
    """
    return jsonify({"message": "RF-DETR Segmentation API is running."})
@app.route("/predict", methods=["POST"])
def predict():
    """Accepts an image file and returns annotated segmentation overlay.

    Expects a multipart/form-data POST with the image in the ``file`` field.
    An optional ``conf`` form field overrides the detection threshold; it falls
    back to 0.3 when absent or not parseable as a float.

    Returns:
        200 with an ``image/png`` body on success.
        400 JSON error when no file is uploaded or it cannot be decoded.
        500 JSON error when model inference fails.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400
    file = request.files["file"]
    threshold = request.form.get("conf", default=0.3, type=float)

    try:
        image = Image.open(file.stream).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (not an image) and truncated-data errors are
        # OSError subclasses; report them as client errors, not server crashes.
        return jsonify({"error": f"Invalid image: {exc}"}), 400
    print(f"[INFO] Image received for inference: {file.filename}")

    try:
        detections = model.predict(image, threshold=threshold)
    except Exception as exc:  # HTTP boundary: surface the failure as JSON
        print(f"[ERROR] Inference failure: {exc}")
        return jsonify({"error": f"Inference failure: {exc}"}), 500
    # sv.Detections has no ``boxes`` attribute; len() gives the instance count.
    print(f"[INFO] Detections found: {len(detections)}")

    annotated = annotate_segmentation(image, detections)
    buf = io.BytesIO()
    annotated.save(buf, format="PNG")
    buf.seek(0)  # rewind so send_file streams from the start
    return send_file(buf, mimetype="image/png")
if __name__ == "__main__":
    # The model is already built at module import time, so there is nothing to
    # warm up here. The former ``init_model()`` call referred to a function that
    # only exists in the commented-out legacy code; it always raised NameError
    # and printed a misleading warning.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)