Spaces:
Running
Running
| import gradio as gr | |
| from ultralytics import YOLO | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from transformers import CLIPProcessor, CLIPModel | |
| import logging | |
| import traceback | |
# Process-wide logging setup; INFO is the level used by all messages below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face repo that provides both the CLIP weights and its processor.
_FASHION_CLIP_REPO = "patrickjohncyh/fashion-clip"

# Detector: weights are read from the local checkpoint file "best.pt".
logger.info("YOLO ๋ก๋ฉ...")
yolo_model = YOLO("best.pt")

# Embedding model and its matching pre-processor, loaded at import time.
logger.info("Fashion-CLIP ๋ก๋ฉ...")
clip_model = CLIPModel.from_pretrained(_FASHION_CLIP_REPO)
clip_processor = CLIPProcessor.from_pretrained(_FASHION_CLIP_REPO)
# Inference only: switch the module to evaluation mode.
clip_model.eval()

logger.info("๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ!")
# --- Detection quality parameters -------------------------------------------

# Confidence threshold: detections scoring below this are treated as noise
# and dropped.
CONF_THRESHOLD = 0.35

# Minimum area ratio: a detection whose box covers less than this fraction of
# the whole image is dropped (0.04 = boxes under 4% of the image are
# considered too small to trust).
MIN_AREA_RATIO = 0.04

# IoU threshold intended for duplicate suppression in YOLO's NMS (lower is
# stricter).
# NOTE(review): nothing in the visible code reads this constant; `predict`
# passes a literal iou=0.80 to the detector instead. Confirm whether it is
# still needed.
IOU_THRESHOLD = 0.45
| def _containment_ratio(inner: dict, outer: dict) -> float: | |
| """ | |
| inner ๋ฐ์ค๊ฐ outer ๋ฐ์ค ๋ด๋ถ์ ํฌํจ๋ ๋น์จ(0.0~1.0)์ ๋ฐํ. | |
| 1.0์ด๋ฉด inner๊ฐ outer์ ์์ ํ ํฌํจ๋จ. | |
| inner ๋ฉด์ ๋๋น ๊ต์งํฉ ๋ฉด์ ์ ๋น์จ๋ก ๊ณ์ฐ. | |
| """ | |
| ix1 = max(inner["x1"], outer["x1"]) | |
| iy1 = max(inner["y1"], outer["y1"]) | |
| ix2 = min(inner["x2"], outer["x2"]) | |
| iy2 = min(inner["y2"], outer["y2"]) | |
| inter_w = max(0.0, ix2 - ix1) | |
| inter_h = max(0.0, iy2 - iy1) | |
| inter_area = inter_w * inter_h | |
| inner_area = max(1.0, (inner["x2"] - inner["x1"]) * (inner["y2"] - inner["y1"])) | |
| return inter_area / inner_area | |
def _select_best_boxes(raw_boxes: list[dict], img_w: int, img_h: int) -> list[dict]:
    """Filter noisy detections, merge boxes per label, and drop nested boxes.

    Processing steps:
      1. Noise filter: drop a box when its confidence is below the threshold
         for its label (0.20 for bottom labels, ``CONF_THRESHOLD`` otherwise)
         or when its area is under ``MIN_AREA_RATIO`` of the image area.
      2. Union merge: surviving boxes that share a label (case-insensitive)
         are merged into their bounding union; the merged box keeps the
         highest confidence and the other fields of the first box seen.
      3. Containment filter: a merged box is removed when at least 75% of it
         lies inside a larger merged box of a different label. Bottom boxes
         never take part in this check, on either side.

    Args:
        raw_boxes: Detection dicts with ``x1``, ``y1``, ``x2``, ``y2`` and
            optionally ``confidence`` and ``label``.
        img_w: Image width.
        img_h: Image height.

    Returns:
        A list of new dicts (the inputs are not mutated), one per remaining
        label. If ``img_w * img_h`` is not positive, ``raw_boxes`` is
        returned unchanged.
    """
    img_area = img_w * img_h
    if img_area <= 0:
        return raw_boxes

    # Bottoms may be partly occluded and score low, so they get a laxer
    # confidence threshold than other labels so that both halves of a split
    # detection survive until the union step.
    BOTTOM_CONF_THRESHOLD = 0.20
    # Share of a box that must lie inside another box for it to be dropped.
    CONTAINMENT_THRESHOLD = 0.75
    # NOTE(review): "ํ์" looks like a mis-encoded copy of the Korean label
    # "하의" (bottoms) and can never match a real label as written; the
    # presumed intended spelling is accepted as well — confirm against the
    # model's class names.
    bottom_labels = ("bottom", "ํ์", "하의")

    # Step 1: confidence and minimum-area noise filter.
    filtered = []
    for box in raw_boxes:
        label = box.get("label", "").lower()
        conf = box.get("confidence", 0.0)
        threshold = BOTTOM_CONF_THRESHOLD if label in bottom_labels else CONF_THRESHOLD
        if conf < threshold:
            logger.debug(f"์ ๋ขฐ๋ ๋ฏธ๋ฌ ๋ฐ์ค ์ ๊ฑฐ: label={label}, conf={conf:.3f} (๊ธฐ์ค={threshold:.2f})")
            continue

        box_area = max(0.0, box["x2"] - box["x1"]) * max(0.0, box["y2"] - box["y1"])
        area_ratio = box_area / img_area
        if area_ratio < MIN_AREA_RATIO:
            logger.debug(
                f"๋ฉด์  ๋ฏธ๋ฌ ๋ฐ์ค ์ ๊ฑฐ: label={label}, "
                f"area_ratio={area_ratio:.3f} (<{MIN_AREA_RATIO})"
            )
            continue
        filtered.append(box)

    # Step 2: merge boxes that share a label into one bounding union.
    # The dict key is the lower-cased label so that grouping agrees with the
    # case-insensitive label handling in steps 1 and 3.
    union_by_label: dict[str, dict] = {}
    for box in filtered:
        label = box.get("label", "unknown")
        key = label.lower()
        merged_box = union_by_label.get(key)
        if merged_box is None:
            # Copy so the caller's dict is never mutated by later merges.
            union_by_label[key] = dict(box)
            continue

        merged_box["x1"] = min(merged_box["x1"], box["x1"])
        merged_box["y1"] = min(merged_box["y1"], box["y1"])
        merged_box["x2"] = max(merged_box["x2"], box["x2"])
        merged_box["y2"] = max(merged_box["y2"], box["y2"])
        # The merged box reports the best confidence among its parts.
        merged_box["confidence"] = max(
            merged_box.get("confidence", 0.0), box.get("confidence", 0.0)
        )
        logger.info(
            f"๋ฐ์ค Union ๋ณํฉ: '{label}' ๋ฐ์ค 2๊ฐ ํฉ์ฐ "
            f"โ ({merged_box['x1']:.0f},{merged_box['y1']:.0f})-({merged_box['x2']:.0f},{merged_box['y2']:.0f})"
        )

    # Step 3: containment filter between boxes of different labels.
    areas = {
        key: (box["x2"] - box["x1"]) * (box["y2"] - box["y1"])
        for key, box in union_by_label.items()
    }
    to_remove = set()
    for label_a, box_a in union_by_label.items():
        if label_a in bottom_labels:
            continue
        for label_b, box_b in union_by_label.items():
            if label_b == label_a or label_b in bottom_labels:
                continue
            ratio = _containment_ratio(box_a, box_b)
            # Only the smaller of the two boxes is ever discarded.
            if ratio >= CONTAINMENT_THRESHOLD and areas[label_a] < areas[label_b]:
                to_remove.add(label_a)
                logger.info(
                    f"ํฌํจ ๊ด๊ณ ํํฐ: '{label_a}' ๋ฐ์ค๊ฐ '{label_b}' ๋ฐ์ค์ "
                    f"{ratio:.0%} ํฌํจ โ '{label_a}' ์ ๊ฑฐ"
                )

    result = [box for key, box in union_by_label.items() if key not in to_remove]
    logger.info(
        f"๋ฐ์ค ํํฐ๋ง: ์๋ณธ {len(raw_boxes)}๊ฐ โ "
        f"์ ๋ขฐ๋/๋ฉด์  ํํฐ ํ {len(filtered)}๊ฐ โ "
        f"Union ๋ณํฉ ํ {len(union_by_label)}๊ฐ โ "
        f"ํฌํจ ๊ด๊ณ ํํฐ ํ {len(result)}๊ฐ"
    )
    return result
def _get_best_crop(pil_img: Image.Image, boxes: list[dict]) -> Image.Image:
    """Crop the image to the highest-confidence box.

    The crop is what gets embedded, so the embedding focuses on the detected
    item rather than the whole picture.

    Args:
        pil_img: Source image.
        boxes: Filtered detection dicts with ``x1``, ``y1``, ``x2``, ``y2``
            and optionally ``confidence`` / ``label``.

    Returns:
        The cropped region, or ``pil_img`` itself when ``boxes`` is empty or
        the crop fails.
    """
    if not boxes:
        return pil_img

    def _confidence(candidate: dict) -> float:
        return candidate.get("confidence", 0.0)

    top_box = max(boxes, key=_confidence)
    left, upper, right, lower = (int(top_box[k]) for k in ("x1", "y1", "x2", "y2"))

    # Clamp to the image bounds while keeping the region at least 1px wide
    # and 1px tall.
    width, height = pil_img.size
    left = max(0, min(left, width - 1))
    right = max(left + 1, min(right, width))
    upper = max(0, min(upper, height - 1))
    lower = max(upper + 1, min(lower, height))

    try:
        region = pil_img.crop((left, upper, right, lower))
        logger.info(
            f"CLIP ์๋ฒ ๋ฉ์ฉ ํฌ๋กญ ์ด๋ฏธ์ง: "
            f"label={top_box.get('label')}, conf={top_box.get('confidence', 0):.3f}, "
            f"crop=({left},{upper},{right},{lower})"
        )
    except Exception as e:
        # Best effort: fall back to the uncropped image rather than failing.
        logger.warning(f"ํฌ๋กญ ์คํจ, ์๋ณธ ์ฌ์ฉ: {e}")
        return pil_img
    return region
def _load_rgb_image(image):
    """Convert a path, numpy array, or PIL image into an RGB PIL image."""
    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    # A str is used as a file path; any other object is stringified first.
    path = image if isinstance(image, str) else str(image)
    # Context manager releases the file handle as soon as the pixels have
    # been converted into a new in-memory image.
    with Image.open(path) as opened:
        return opened.convert("RGB")


def _collect_raw_boxes(results) -> list[dict]:
    """Flatten detector results into plain dicts (xyxy, confidence, label)."""
    raw_boxes = []
    for result in results or []:
        if not result.boxes:
            continue
        names = getattr(result, "names", None)
        for box in result.boxes:
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            conf = float(box.conf[0]) if box.conf is not None else 0.0
            cls = int(box.cls[0]) if box.cls is not None else 0
            label = names.get(cls, "unknown") if names is not None else "unknown"
            raw_boxes.append(
                {
                    "x1": x1,
                    "y1": y1,
                    "x2": x2,
                    "y2": y2,
                    "confidence": conf,
                    "label": label,
                }
            )
    return raw_boxes


def _embed_image(img) -> list:
    """Return the L2-normalised CLIP image embedding of ``img`` as a list."""
    inputs = clip_processor(images=img, return_tensors="pt")
    with torch.no_grad():
        # Explicit vision_model -> visual_projection call chain.
        vision_outputs = clip_model.vision_model(**inputs)
        features = clip_model.visual_projection(vision_outputs.pooler_output)
        # Unit-length vectors so that a dot product equals cosine similarity.
        embedding = torch.nn.functional.normalize(features, p=2, dim=1)
    return embedding[0].cpu().tolist()


def predict(image):
    """Detect fashion items in ``image`` and embed the best detection.

    Args:
        image: A numpy array, PIL image, file path string, or any object
            whose ``str()`` is a file path. ``None`` yields an error payload.

    Returns:
        A dict that always contains ``status``, ``embedding``, ``boxes``,
        ``label`` and ``category``. On success ``status`` is ``"success"``,
        ``boxes`` holds the post-processed detections, and ``label`` is the
        label of the highest-confidence box or ``"full_image"`` when no
        usable label was detected. On failure ``status`` is ``"error"`` and
        ``error_message`` (plus ``traceback`` for exceptions) is included.
        This function never raises; every exception is converted into an
        error payload.
    """
    try:
        if image is None:
            return {
                "status": "error",
                "error_message": "No image provided",
                "embedding": None,
                "boxes": [],
                "label": "unknown",
                "category": None,
            }

        pil_img = _load_rgb_image(image)
        img_w, img_h = pil_img.size

        # Detection runs with deliberately permissive settings:
        #   conf=0.10 collects weak/occluded detections; the real confidence
        #             filtering happens per label in _select_best_boxes.
        #   iou=0.80  loosens NMS so two overlapping boxes of the same label
        #             (e.g. left and right trouser leg) both survive and can
        #             be merged afterwards.
        results = yolo_model.predict(
            source=pil_img,
            conf=0.10,
            iou=0.80,
            save=False,
            verbose=False,
        )
        raw_boxes = _collect_raw_boxes(results)

        # Post-processing: noise removal, per-label union merge, and the
        # containment filter.
        filtered_boxes = _select_best_boxes(raw_boxes, img_w, img_h)

        # Representative category = label of the most confident box.
        detected_category = None
        if filtered_boxes:
            best_box = max(filtered_boxes, key=lambda b: b.get("confidence", 0.0))
            label = best_box.get("label", "")
            if label and label != "unknown":
                detected_category = label

        # Embed the crop of the best box (the full image when there is none).
        embed_img = _get_best_crop(pil_img, filtered_boxes)
        embedding_list = _embed_image(embed_img)

        logger.info(
            f"์๋ฒ ๋ฉ ์์ฑ ์๋ฃ: dim={len(embedding_list)}, "
            f"filtered_boxes={len(filtered_boxes)}, "
            f"category={detected_category}"
        )
        return {
            "status": "success",
            "embedding": embedding_list,
            "boxes": filtered_boxes,
            "label": detected_category if detected_category else "full_image",
            "category": detected_category,
        }
    except Exception as e:
        # Top-level boundary for the UI handler: report instead of crashing.
        err_msg = traceback.format_exc()
        logger.error(f"์ถ๋ก  ์ค ์์ธ ๋ฐ์: {err_msg}")
        # NOTE(review): the full traceback is returned to the caller, which
        # exposes internal paths and code; kept for compatibility with
        # existing consumers — consider dropping it for public deployments.
        return {
            "status": "error",
            "error_message": str(e),
            "traceback": err_msg,
            "embedding": None,
            "boxes": [],
            "label": "unknown",
            "category": None,
        }
# Gradio front end: one image input (delivered to `predict` as a numpy array)
# and the result dict rendered as JSON.
demo = gr.Interface(fn=predict, inputs=gr.Image(type="numpy"), outputs=gr.JSON())

# Start the server at import time; show_error=True surfaces errors in the UI.
demo.launch(show_error=True)