"""Auto-annotate a YOLO dataset with multi-class clothing/hair boxes via the Gemini REST API.

For every image in the dataset directory that has no label file yet, the image
(plus any ``<stem>_meta.txt`` product metadata stored next to it) is sent to
Gemini. The model's ``ITEM: <class>, BOX: [ymin, xmin, ymax, xmax]`` lines
(coordinates normalized to 0-1000) are converted to YOLO label lines
``<class_id> <x_center> <y_center> <width> <height>`` (normalized to 0-1) and
written to ``<stem>.txt`` in the same directory.
"""

import base64
import json
import logging
import os
import re
import sys
import time

import requests
from dotenv import load_dotenv

# Load .env from the parent directory of this script (not the current working directory).
load_dotenv(os.path.join(os.path.dirname(__file__), "../.env"))

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
    logging.error("GEMINI_API_KEY not found in .env")
    sys.exit(1)

# Define Classes -- the list index is the YOLO class id, so the order matters.
CLASSES = [
    "T-shirt", "Shirt", "Hoodie", "Jacket_Coat", "Skirt", "Pants_Shorts", "Dress",
    "Swimsuit_Underwear", "Uniform", "Ponytail", "Twin-tails", "Hair_Other",
    "Shoes_Boots", "Accessory", "Other_Outfit",
]
CLASS_TO_ID = {cls.lower(): i for i, cls in enumerate(CLASSES)}

# Image extensions we annotate, mapped to the MIME type Gemini must be told.
_MIME_TYPES = {".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png"}

# HTTP statuses worth retrying: rate limiting and transient server errors.
_RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})


def _normalize_class_name(name):
    """Lowercase *name* and collapse runs of whitespace, '/' or '-' into '_'."""
    return re.sub(r"[\s/\-]+", "_", name.strip().lower())


# Lenient lookup so spellings such as "Jacket/Coat", "Twin tails" or "t-shirt" still map to an id.
_NORMALIZED_CLASS_TO_ID = {_normalize_class_name(cls): i for i, cls in enumerate(CLASSES)}

# Matches "ITEM: Skirt, BOX: [500, 200, 850, 800]". It also tolerates a bracketed
# class ("ITEM: [Skirt]", as the prompt's "[Class Name]" placeholder invites),
# markdown bold markers ("**ITEM:**") and any letter case.
_ITEM_RE = re.compile(
    r"ITEM:\**\s*\[?\s*([A-Za-z0-9_\-/ ]+?)\s*\]?\s*,\s*\**\s*BOX:\**\s*"
    r"\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]",
    re.IGNORECASE,
)


def get_yolo_format(box, img_width=None, img_height=None):
    """Convert a Gemini box to YOLO ``(x_center, y_center, width, height)``.

    Args:
        box: ``[ymin, xmin, ymax, xmax]`` in Gemini's 0-1000 normalized space.
            Values are clamped to 0-1000, and swapped min/max pairs are reordered.
        img_width, img_height: Unused, because both input and output are
            normalized. Kept only for backward compatibility with existing callers.

    Returns:
        A tuple of floats in [0, 1]. Width or height is 0 for a degenerate box.
    """
    ymin, xmin, ymax, xmax = (min(max(v, 0), 1000) / 1000 for v in box)
    if xmin > xmax:
        xmin, xmax = xmax, xmin
    if ymin > ymax:
        ymin, ymax = ymax, ymin
    x_center = (xmin + xmax) / 2
    y_center = (ymin + ymax) / 2
    return x_center, y_center, xmax - xmin, ymax - ymin


def _build_prompt(context_info):
    """Build the annotation prompt, embedding the product metadata text."""
    return (
        "You are an expert computer vision system ASSISTED by product metadata. Analyze this image to find ALL distinct 3D character clothing items, hairstyles, and accessories.\n"
        f"=== METADATA (Title, Description, Tags) ===\n{context_info}\n===========================================\n"
        "CRITICAL INSTRUCTION: Read the METADATA above first. Use it to understand EXACTLY what the creator is selling (e.g., if it says 'parka', look for a Hoodie. If it says 'skirt', look for a Skirt. If it mentions a hairstyle like 'twin tails', look for Twin-tails).\n\n"
        f"You MUST classify each found item STRICTLY into one of the following exact categories: {', '.join(CLASSES)}.\n"
        "First, think step-by-step: 'Based on the metadata, this product is X. In the image, X is located at the top/bottom. I also see Y and Z.'\n"
        "Then, for EACH valid item you found from the class list, provide its class and a STRICTLY TIGHT bounding box that wraps ONLY that specific item.\n"
        "Output the final results in the following format exactly, one item per line:\n"
        "ITEM: [Class Name], BOX: [ymin, xmin, ymax, xmax]\n"
        "where coordinates are normalized integers from 0 to 1000. Example:\n"
        "ITEM: Skirt, BOX: [500, 200, 850, 800]\n"
        "ITEM: T-shirt, BOX: [200, 220, 520, 780]\n"
        "Ensure the class name exactly matches the list provided."
    )


def _extract_text(res_json):
    """Return the concatenated answer text of the first candidate, or "" if there is none.

    Parts flagged as ``thought`` are skipped. A missing ``content``/``parts``
    (e.g. a blocked or truncated candidate) yields "" instead of raising.
    """
    candidates = res_json.get("candidates") or []
    if not candidates:
        return ""
    parts = (candidates[0].get("content") or {}).get("parts") or []
    return "".join(p.get("text", "") for p in parts if not p.get("thought")).strip()


def _parse_items(text):
    """Turn the model's ``ITEM: ..., BOX: [...]`` lines into YOLO label lines.

    Unknown classes and zero-area boxes are logged and skipped.
    """
    label_lines = []
    for match in _ITEM_RE.finditer(text):
        raw_name = match.group(1)
        class_id = _NORMALIZED_CLASS_TO_ID.get(_normalize_class_name(raw_name))
        if class_id is None:
            logging.warning(f" Unknown class predicted: {raw_name}")
            continue
        box = [int(n) for n in match.groups()[1:]]
        xc, yc, rw, rh = get_yolo_format(box)
        if rw <= 0 or rh <= 0:
            logging.warning(f" Degenerate box skipped for {raw_name}: {box}")
            continue
        label_lines.append(f"{class_id} {xc:.6f} {yc:.6f} {rw:.6f} {rh:.6f}")
    return label_lines


def _post_with_retry(url, headers, payload, *, timeout, max_attempts=3, backoff=5.0):
    """POST *payload* as JSON, retrying network errors and HTTP 429/5xx.

    Waits ``backoff * 2**(attempt-1)`` seconds between attempts. On the last
    attempt, the response is returned as-is or the network exception propagates.
    """
    max_attempts = max(1, max_attempts)
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        except requests.RequestException as e:
            if attempt == max_attempts:
                raise
            logging.warning(f" Request failed ({e}); retry {attempt}/{max_attempts - 1}")
        else:
            if response.status_code not in _RETRYABLE_STATUS or attempt == max_attempts:
                return response
            logging.warning(f" HTTP {response.status_code}; retry {attempt}/{max_attempts - 1}")
        time.sleep(backoff * 2 ** (attempt - 1))


def _annotate_image(img_path, img_file, label_file, context_info, url, headers, timeout):
    """Query Gemini for one image and write its YOLO label file if any item was parsed."""
    with open(img_path, "rb") as f:
        img_data = base64.b64encode(f.read()).decode('utf-8')
    # Declare the real format; PNGs were previously mislabelled as JPEG.
    mime_type = _MIME_TYPES.get(os.path.splitext(img_file)[1].lower(), "image/jpeg")

    payload = {
        "contents": [{
            "parts": [
                {"text": _build_prompt(context_info)},
                {"inline_data": {"mime_type": mime_type, "data": img_data}},
            ]
        }]
    }
    response = _post_with_retry(url, headers, payload, timeout=timeout)
    res_json = response.json()

    text = _extract_text(res_json)
    if not text:
        # No usable answer (error, blocked prompt, empty candidate): log the raw JSON for diagnosis.
        logging.error(f" API ERROR (HTTP {response.status_code}): {json.dumps(res_json)}")
        return

    boxes_found = _parse_items(text)
    if boxes_found:
        with open(label_file, 'w', encoding='utf-8') as lf:
            lf.write("\n".join(boxes_found) + "\n")
        logging.info(f" SUCCESS: {img_file} => {len(boxes_found)} items mapped.")
    else:
        logging.warning(f" PARSE FAILED or no items: {text}")


def auto_annotate(dataset_dir="data/yolo_dataset/images", *, model="gemini-2.5-pro",
                  delay=1.0, timeout=300.0):
    """Annotate every not-yet-labelled image in *dataset_dir* with Gemini.

    Args:
        dataset_dir: Directory holding the images, their ``<stem>_meta.txt``
            metadata, and the ``<stem>.txt`` label files written here. Relative
            paths resolve against the current working directory.
        model: Gemini model id used for ``generateContent``.
        delay: Seconds to sleep after every attempted image (rate-limit safety).
        timeout: Per-request HTTP timeout in seconds.

    Images that already have a label file are skipped, so the run is resumable.
    Images for which no item is parsed get no label file and are retried on
    the next run.
    """
    if not os.path.exists(dataset_dir):
        logging.error(f"Directory {dataset_dir} not found.")
        return

    # Ensure classes.txt exists and matches our list.
    classes_path = os.path.join(dataset_dir, "classes.txt")
    with open(classes_path, "w", encoding='utf-8') as f:
        f.write("\n".join(CLASSES) + "\n")

    image_files = sorted(
        f for f in os.listdir(dataset_dir) if os.path.splitext(f)[1].lower() in _MIME_TYPES
    )
    logging.info(f"Processing {len(image_files)} images for Multi-Class Annotation via REST API...")

    # The key travels in a header, not the URL, so it cannot leak into logged exception messages.
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    headers = {'Content-Type': 'application/json', 'x-goog-api-key': api_key}

    total = len(image_files)
    for i, img_file in enumerate(image_files, start=1):
        img_path = os.path.join(dataset_dir, img_file)
        stem = os.path.splitext(img_file)[0]
        label_file = os.path.join(dataset_dir, stem + ".txt")
        # Resumable: skip labelled images. Note an image named "classes.*" is always
        # skipped, because its label path is classes.txt.
        if os.path.exists(label_file):
            continue

        meta_file = os.path.join(dataset_dir, stem + "_meta.txt")
        context_info = ""
        if os.path.exists(meta_file):
            with open(meta_file, 'r', encoding='utf-8') as f:
                context_info = f.read()

        logging.info(f"[{i}/{total}] Analyzing Multi-Class: {img_file}")
        try:
            _annotate_image(img_path, img_file, label_file, context_info, url, headers, timeout)
        except Exception as e:
            # Per-image boundary: log and keep processing the rest of the dataset.
            logging.error(f" Error: {e}")
        finally:
            time.sleep(delay)  # Rate limit safety, also after failures.


if __name__ == "__main__":
    auto_annotate()