import os
import glob
import random
import shutil
import torch
from PIL import Image
from tqdm import tqdm
from transformers import Owlv2Processor, Owlv2ForObjectDetection
# ---------------------------------------------------------------------
# Phase 2: list of trending avatar clothing/accessory tags (AI-suggested).
# These English keywords are sent as text prompts to the zero-shot
# detector, which searches each image for matching bounding boxes.
# ---------------------------------------------------------------------
TEXT_PROMPTS = [
    "gothic dress",
    "cyberpunk clothes",
    "techwear",
    "maid outfit",
    "swimsuit",
    "bikini",
    "school uniform",
    "casual wear",
    "frill skirt",
    "hoodie",
    "jacket",
    "shorts",
    "twintails hair",
    "ponytail hair",
    "side ponytail hair",
    "braids hair",
    "short hair",
    "bob hair",
    "long hair",
    "straight hair",
    "curly hair",
    "cat ears",  # nekomimi
    "rabbit ears",  # usagimimi
    "animal ears",  # kemonomimi
    "sheep ears",
    "fox ears",
    "horns",
    "hair clip",
    "headset",
    "halo",
    "crown",
    "glasses",
    "goggles",
    "choker",
    "ribbon",
    "hat",
]
# To stay consistent with the tag names BOOTH users are likely to use and to
# keep the dataset coherent, detections found for these prompts are mapped to
# the corresponding ID (or a new ID) in the YOLO dataset.
# For simplicity, the index of the matched prompt is used as a provisional
# class ID; data.yaml is regenerated automatically when the merged dataset is
# built later.
def _box_to_yolo_line(box, cls_id, img_size):
    """Convert one absolute [xmin, ymin, xmax, ymax] box tensor to a YOLO label line.

    `cls_id` is the index into TEXT_PROMPTS, used as a provisional class id.
    `img_size` is the PIL `(width, height)` of the original image.
    """
    x_min, y_min, x_max, y_max = [round(coord, 2) for coord in box.tolist()]
    img_w, img_h = img_size
    # YOLO format: class id followed by normalized center x/y and width/height.
    cx = ((x_min + x_max) / 2) / img_w
    cy = ((y_min + y_max) / 2) / img_h
    w = (x_max - x_min) / img_w
    h = (y_max - y_min) / img_h
    # Clamp into [0, 1] — the detector can emit slightly out-of-bounds boxes.
    cx, cy, w, h = (max(0, min(v, 1)) for v in (cx, cy, w, h))
    return f"{cls_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"


def _detect_labels(img_path, processor, model, device, confidence_threshold, filename):
    """Run OWLv2 on one image and return YOLO label lines above the threshold."""
    image = Image.open(img_path).convert("RGB")
    # OWL-ViT family models respond best to "a photo of a <item>" style prompts.
    texts = [[f"a photo of a {prompt}" for prompt in TEXT_PROMPTS]]
    inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Target sizes are (height, width) of the original image so boxes are
    # rescaled back from model space to pixel coordinates.
    target_sizes = torch.tensor([image.size[::-1]]).to(device)
    # threshold=0.0 keeps every candidate so the best raw score can be logged;
    # the real confidence filter is applied in the loop below.
    results = processor.image_processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=0.0
    )
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]

    if len(scores) > 0:
        best = labels[scores.argmax()].item()
        print(f"[{filename}] Max score: {scores.max().item():.3f} for label '{TEXT_PROMPTS[best]}'")
    else:
        print(f"[{filename}] No boxes generated at all.")

    label_lines = []
    for box, score, label in zip(boxes, scores, labels):
        if score.item() < confidence_threshold:
            continue
        label_lines.append(_box_to_yolo_line(box, label.item(), image.size))
    return label_lines


def perform_zero_shot_annotation(num_samples=1000, confidence_threshold=0.1):
    """Auto-annotate raw scraped images in YOLO format via zero-shot OWLv2 detection.

    Samples up to `num_samples` JPEGs from scraper/data/raw_images, queries
    every prompt in TEXT_PROMPTS against each image, and writes YOLO image/label
    pairs under yolo_dataset/zero_shot_auto for images with at least one
    detection. The prompt index serves as a provisional class id, and the
    prompt list is saved as classes.txt for later remapping.

    Args:
        num_samples: maximum number of raw images to process.
        confidence_threshold: minimum detection score for a box to be kept.
    """
    backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # OWLv2 ensemble checkpoint: a strong open-vocabulary zero-shot detector.
    model_id = "google/owlv2-base-patch16-ensemble"
    print(f"Loading {model_id} ...")
    try:
        processor = Owlv2Processor.from_pretrained(model_id)
        model = Owlv2ForObjectDetection.from_pretrained(model_id).to(device)
    except Exception as e:
        # Best-effort: report and bail out rather than crash the pipeline.
        print(f"Error loading model: {e}")
        print("Please ensure you have internet access to download the weights on first run.")
        return

    raw_images_dir = os.path.join(backend_dir, "scraper", "data", "raw_images")
    output_dataset_dir = os.path.join(backend_dir, "yolo_dataset", "zero_shot_auto")
    images_out_dir = os.path.join(output_dataset_dir, "images")
    labels_out_dir = os.path.join(output_dataset_dir, "labels")
    os.makedirs(images_out_dir, exist_ok=True)
    os.makedirs(labels_out_dir, exist_ok=True)

    all_images = glob.glob(os.path.join(raw_images_dir, "*.jpg"))
    if not all_images:
        print("No raw images found.")
        return
    random.shuffle(all_images)
    samples = all_images[:num_samples]

    print(f"Starting zero-shot annotation for {len(samples)} images against {len(TEXT_PROMPTS)} tags...")

    successful_annotated = 0
    for img_path in tqdm(samples):
        filename = os.path.basename(img_path)
        img_name, _ = os.path.splitext(filename)
        try:
            label_lines = _detect_labels(
                img_path, processor, model, device, confidence_threshold, filename
            )
            if label_lines:
                # Copy the image and write its label file only when something
                # was detected, so the dataset contains no empty annotations.
                shutil.copy2(img_path, os.path.join(images_out_dir, filename))
                label_path = os.path.join(labels_out_dir, f"{img_name}.txt")
                with open(label_path, "w", encoding="utf-8") as f:
                    f.write("\n".join(label_lines) + "\n")
                successful_annotated += 1
        except Exception as e:
            print(f"Error processing {filename}: {e}")

    print(f"Zero-shot auto-annotation complete! {successful_annotated} images successfully annotated.")

    # Persist the prompt list so provisional class ids can be mapped back to
    # tag names when the merged dataset / data.yaml is built later.
    with open(os.path.join(output_dataset_dir, "classes.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(TEXT_PROMPTS) + "\n")
if __name__ == "__main__":
    # Run on 1000 images with the expanded tag list — quality over quantity
    # for this iteration.
    perform_zero_shot_annotation(num_samples=1000, confidence_threshold=0.1)