File size: 7,234 Bytes
a06f06c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import glob
import random
import shutil
import torch
from PIL import Image
from tqdm import tqdm
from transformers import Owlv2Processor, Owlv2ForObjectDetection

# ---------------------------------------------------------------------
# Phase 2: trending avatar clothing / styling tags (AI-suggested).
# Each English keyword is sent to the zero-shot detector as a text
# prompt, and the detector searches the image for matching boxes.
#
# Order is significant: a tag's index in TEXT_PROMPTS is the class ID
# written to the YOLO label files and its line position in classes.txt,
# so reordering or inserting entries remaps previously generated labels.
# Append new tags to the end of the last group only.
# ---------------------------------------------------------------------
_OUTFIT_PROMPTS = [
    "gothic dress",
    "cyberpunk clothes",
    "techwear",
    "maid outfit",
    "swimsuit",
    "bikini",
    "school uniform",
    "casual wear",
    "frill skirt",
    "hoodie",
    "jacket",
    "shorts",
]

_HAIRSTYLE_PROMPTS = [
    "twintails hair",
    "ponytail hair",
    "side ponytail hair",
    "braids hair",
    "short hair",
    "bob hair",
    "long hair",
    "straight hair",
    "curly hair",
]

_EAR_AND_HORN_PROMPTS = [
    "cat ears",     # nekomimi
    "rabbit ears",  # usagimimi
    "animal ears",  # kemonomimi
    "sheep ears",
    "fox ears",
    "horns",
]

_ACCESSORY_PROMPTS = [
    "hair clip",
    "headset",
    "halo",
    "crown",
    "glasses",
    "goggles",
    "choker",
    "ribbon",
    "hat",
]

TEXT_PROMPTS = (
    _OUTFIT_PROMPTS
    + _HAIRSTYLE_PROMPTS
    + _EAR_AND_HORN_PROMPTS
    + _ACCESSORY_PROMPTS
)

# Class IDs are provisional. To keep things simple, each detection is
# labelled with the index of its matching prompt, and the index -> prompt
# mapping is saved as classes.txt next to the labels. Aligning these IDs
# with the tag names BOOTH users actually use (and updating data.yaml)
# is meant to happen later, when the merged dataset is built; that step
# is not part of this file.

def _box_to_yolo_line(box, cls_id, img_w, img_h):
    """Convert one absolute-pixel box into a YOLO label line.

    Args:
        box: ``[x_min, y_min, x_max, y_max]`` in pixels of the original image.
        cls_id: Integer class index (the prompt's index in ``TEXT_PROMPTS``).
        img_w: Original image width in pixels.
        img_h: Original image height in pixels.

    Returns:
        ``"cls cx cy w h"`` with the four coordinates normalized to [0, 1],
        or ``None`` when nothing of the box remains after clipping it to the
        image.
    """
    x_min, y_min, x_max, y_max = box
    # Clip the corners (not the derived centre/size) so a box spilling over
    # the border shrinks to its visible part instead of being distorted.
    x_min = min(max(x_min, 0.0), img_w)
    x_max = min(max(x_max, 0.0), img_w)
    y_min = min(max(y_min, 0.0), img_h)
    y_max = min(max(y_max, 0.0), img_h)
    if x_max <= x_min or y_max <= y_min:
        return None  # zero-area / fully off-image box: useless as a label

    cx = (x_min + x_max) / 2 / img_w
    cy = (y_min + y_max) / 2 / img_h
    w = (x_max - x_min) / img_w
    h = (y_max - y_min) / img_h
    return f"{cls_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"


def perform_zero_shot_annotation(num_samples=1000, confidence_threshold=0.1, *, seed=None):
    """Auto-label a random sample of raw images with OWLv2 and write a YOLO dataset.

    Every sampled image is queried against all entries of ``TEXT_PROMPTS``.
    Boxes scoring at least ``confidence_threshold`` become YOLO label lines
    whose class ID is the prompt's index in ``TEXT_PROMPTS``. Images with at
    least one kept box are copied into the output dataset. A ``classes.txt``
    mapping IDs back to prompts is written next to ``images/`` and ``labels/``.

    Inputs are read from ``<backend>/scraper/data/raw_images/*.jpg``, and
    outputs go to ``<backend>/yolo_dataset/zero_shot_auto/``. Here
    ``<backend>`` is the parent of the directory containing this file.

    Args:
        num_samples: Maximum number of images to annotate (negative -> none).
        confidence_threshold: Minimum detection score for a box to be kept.
        seed: Optional seed for choosing the sample. ``None`` (default)
            shuffles with the global ``random`` module, as before.

    Returns:
        None. Returns early, after printing a message, if the model cannot
        be loaded or no raw images are found. Per-image errors are printed
        and that image is skipped.
    """
    backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # OWLv2 is an open-vocabulary detector; the "ensemble" checkpoint has
    # good zero-shot accuracy.
    model_id = "google/owlv2-base-patch16-ensemble"
    print(f"Loading {model_id} ...")
    try:
        processor = Owlv2Processor.from_pretrained(model_id)
        model = Owlv2ForObjectDetection.from_pretrained(model_id).to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Please ensure you have internet access to download the weights on first run.")
        return

    raw_images_dir = os.path.join(backend_dir, "scraper", "data", "raw_images")
    output_dataset_dir = os.path.join(backend_dir, "yolo_dataset", "zero_shot_auto")
    images_out_dir = os.path.join(output_dataset_dir, "images")
    labels_out_dir = os.path.join(output_dataset_dir, "labels")

    os.makedirs(images_out_dir, exist_ok=True)
    os.makedirs(labels_out_dir, exist_ok=True)

    # Sorting makes the pre-shuffle order filesystem-independent, so a seed
    # actually reproduces the same sample.
    all_images = sorted(glob.glob(os.path.join(raw_images_dir, "*.jpg")))
    if not all_images:
        print("No raw images found.")
        return

    rng = random if seed is None else random.Random(seed)
    rng.shuffle(all_images)
    samples = all_images[:max(num_samples, 0)]

    # Write the ID -> prompt mapping up front so labels from an interrupted
    # run are still interpretable.
    with open(os.path.join(output_dataset_dir, "classes.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(TEXT_PROMPTS) + "\n")

    print(f"Starting zero-shot annotation for {len(samples)} images against {len(TEXT_PROMPTS)} tags...")

    # OWL models respond best to "a photo of a <item>" phrasing. The outer
    # list holds one query list per image (a batch of one). It is the same
    # for every image, so build it once.
    texts = [[f"a photo of a {prompt}" for prompt in TEXT_PROMPTS]]

    successful_annotated = 0

    for img_path in tqdm(samples):
        filename = os.path.basename(img_path)
        img_name = os.path.splitext(filename)[0]

        try:
            # The context manager closes the file handle Image.open keeps.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            img_w, img_h = image.size

            inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)

            # OWLv2 pads every image to a square (bottom/right) before
            # resizing, so its normalized boxes are relative to that square.
            # Scaling by (side, side) maps them back to original-image pixels.
            # NOTE(review): newer transformers releases apply this max-side
            # scaling internally when given (h, w); passing the square size
            # yields the same result there too. Re-check on upgrades.
            side = max(img_w, img_h)
            target_sizes = torch.tensor([[side, side]], device=device)

            # threshold=0.0 keeps every box so the diagnostic below can report
            # the best raw score; the real threshold is applied afterwards.
            result = processor.image_processor.post_process_object_detection(
                outputs=outputs, target_sizes=target_sizes, threshold=0.0
            )[0]
            boxes, scores, labels = result["boxes"], result["scores"], result["labels"]

            # tqdm.write keeps per-image diagnostics from breaking the bar.
            if len(scores) > 0:
                best = scores.argmax()
                tqdm.write(f"[{filename}] Max score: {scores[best].item():.3f} for label '{TEXT_PROMPTS[labels[best].item()]}'")
            else:
                tqdm.write(f"[{filename}] No boxes generated at all.")

            # One host transfer per tensor instead of per-element .item().
            label_lines = []
            for box, score, cls_id in zip(boxes.tolist(), scores.tolist(), labels.tolist()):
                if score < confidence_threshold:
                    continue
                line = _box_to_yolo_line(box, cls_id, img_w, img_h)
                if line is not None:
                    label_lines.append(line)

            if label_lines:
                shutil.copy2(img_path, os.path.join(images_out_dir, filename))
                label_path = os.path.join(labels_out_dir, f"{img_name}.txt")
                with open(label_path, "w", encoding="utf-8") as f:
                    f.write("\n".join(label_lines) + "\n")
                successful_annotated += 1

        except Exception as e:
            # Best-effort batch job: one bad image must not abort the run.
            tqdm.write(f"Error processing {filename}: {e}")

    print(f"Zero-shot auto-annotation complete! {successful_annotated} images successfully annotated.")

if __name__ == "__main__":
    # Script entry point for this iteration: favour quality over quantity by
    # sampling up to 1000 images against the expanded tag list, keeping only
    # boxes that score at least 0.1.
    run_options = {"num_samples": 1000, "confidence_threshold": 0.1}
    perform_zero_shot_annotation(**run_options)