# booth-pic-api/backend/scripts/zero_shot_annotate.py
# github-actions
# Deploy to HF (clean history with LFS)
# a06f06c
import os
import glob
import random
import shutil
import torch
from PIL import Image
from tqdm import tqdm
from transformers import Owlv2Processor, Owlv2ForObjectDetection
# ---------------------------------------------------------------------
# Phase 2: trending avatar clothing/accessory tags (AI-suggested).
# Each English keyword below is handed to the zero-shot detector as a text
# prompt so that it can search the image for matching bounding boxes.
# The position of a prompt inside TEXT_PROMPTS doubles as its class ID, so
# the groups must stay concatenated in exactly this order.
# ---------------------------------------------------------------------
_OUTFIT_PROMPTS = [
    "gothic dress",
    "cyberpunk clothes",
    "techwear",
    "maid outfit",
    "swimsuit",
    "bikini",
    "school uniform",
    "casual wear",
    "frill skirt",
    "hoodie",
    "jacket",
    "shorts",
]

_HAIRSTYLE_PROMPTS = [
    "twintails hair",
    "ponytail hair",
    "side ponytail hair",
    "braids hair",
    "short hair",
    "bob hair",
    "long hair",
    "straight hair",
    "curly hair",
]

_EAR_AND_HORN_PROMPTS = [
    "cat ears",  # nekomimi
    "rabbit ears",  # usagimimi
    "animal ears",  # kemonomimi
    "sheep ears",
    "fox ears",
    "horns",
]

_ACCESSORY_PROMPTS = [
    "hair clip",
    "headset",
    "halo",
    "crown",
    "glasses",
    "goggles",
    "choker",
    "ribbon",
    "hat",
]

TEXT_PROMPTS = (
    _OUTFIT_PROMPTS
    + _HAIRSTYLE_PROMPTS
    + _EAR_AND_HORN_PROMPTS
    + _ACCESSORY_PROMPTS
)
# To stay consistent with the tag names BOOTH users are likely to use (and
# consistent within the dataset), detected tags are meant to be mapped onto
# fixed (or newly assigned) IDs in the YOLO dataset.
# NOTE: for simplicity this script uses the index of the matched prompt as a
# temporary class ID; per the original author's note, data.yaml is meant to be
# updated automatically later, when the merged dataset is built.
def _xyxy_to_yolo(box, img_w, img_h):
    """Convert a pixel-space ``[xmin, ymin, xmax, ymax]`` box to YOLO format.

    The corners are clipped to the image *before* the centre and size are
    derived, so the resulting box always lies inside the image.

    Args:
        box: Four pixel coordinates ``[xmin, ymin, xmax, ymax]``.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        ``(cx, cy, w, h)`` normalized to the 0-1 range, or ``None`` when the
        box has no area left inside the image.
    """
    # Keep the two-decimal pixel rounding of the original implementation so
    # labels for fully in-bounds boxes are unchanged.
    x_min, y_min, x_max, y_max = (round(value, 2) for value in box)

    x_min = min(max(x_min, 0.0), img_w)
    x_max = min(max(x_max, 0.0), img_w)
    y_min = min(max(y_min, 0.0), img_h)
    y_max = min(max(y_max, 0.0), img_h)

    # A zero-area box is not a usable annotation.
    if x_max <= x_min or y_max <= y_min:
        return None

    return (
        (x_min + x_max) / 2 / img_w,
        (y_min + y_max) / 2 / img_h,
        (x_max - x_min) / img_w,
        (y_max - y_min) / img_h,
    )


def perform_zero_shot_annotation(num_samples=1000, confidence_threshold=0.1):
    """Auto-annotate a random sample of raw images with OWLv2 zero-shot detection.

    Loads ``google/owlv2-base-patch16-ensemble``, runs every prompt in
    ``TEXT_PROMPTS`` against up to ``num_samples`` randomly chosen ``*.jpg``
    files from ``<backend>/scraper/data/raw_images`` and, for each image with
    at least one detection at or above ``confidence_threshold``, copies the
    image and writes a YOLO label file under
    ``<backend>/yolo_dataset/zero_shot_auto``. A ``classes.txt`` listing
    ``TEXT_PROMPTS`` (class ID == line index) is written at the end.

    Args:
        num_samples: Maximum number of images to process; values below zero
            are treated as zero.
        confidence_threshold: Minimum detection score for a box to be kept.

    Returns:
        None. Returns early (after printing a message) if the model cannot be
        loaded or no raw images are found.
    """
    backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # OWLv2 is an open-vocabulary detector; the base patch16 ensemble
    # checkpoint was chosen by the original author for its zero-shot accuracy.
    model_id = "google/owlv2-base-patch16-ensemble"
    print(f"Loading {model_id} ...")
    try:
        processor = Owlv2Processor.from_pretrained(model_id)
        model = Owlv2ForObjectDetection.from_pretrained(model_id).to(device)
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Please ensure you have internet access to download the weights on first run.")
        return

    raw_images_dir = os.path.join(backend_dir, "scraper", "data", "raw_images")
    output_dataset_dir = os.path.join(backend_dir, "yolo_dataset", "zero_shot_auto")
    images_out_dir = os.path.join(output_dataset_dir, "images")
    labels_out_dir = os.path.join(output_dataset_dir, "labels")
    os.makedirs(images_out_dir, exist_ok=True)
    os.makedirs(labels_out_dir, exist_ok=True)

    all_images = glob.glob(os.path.join(raw_images_dir, "*.jpg"))
    if not all_images:
        print("No raw images found.")
        return

    random.shuffle(all_images)
    # Guard against a negative count, which would otherwise slice from the end.
    samples = all_images[:max(0, num_samples)]
    print(f"Starting zero-shot annotation for {len(samples)} images against {len(TEXT_PROMPTS)} tags...")

    # The prompt batch is identical for every image, so build it once.
    texts = [[f"a photo of a {prompt}" for prompt in TEXT_PROMPTS]]

    successful_annotated = 0
    for img_path in tqdm(samples):
        filename = os.path.basename(img_path)
        img_name, _ext = os.path.splitext(filename)
        try:
            # convert() returns an independent, fully loaded copy, so the
            # underlying file handle can be closed right away.
            with Image.open(img_path) as src:
                image = src.convert("RGB")
            img_w, img_h = image.size

            inputs = processor(text=texts, images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)

            # Target size is (height, width) so boxes come back in original
            # pixel coordinates.
            # NOTE(review): OWLv2 pads inputs to a square; depending on the
            # installed transformers version the rescaling may need
            # max(height, width) for both entries - confirm.
            target_sizes = torch.tensor([[img_h, img_w]], device=device)

            # Post-process once without a score cut-off so the best raw score
            # can be reported; the real threshold is applied below.
            detections = processor.image_processor.post_process_object_detection(
                outputs=outputs, target_sizes=target_sizes, threshold=0.0
            )[0]
            boxes = detections["boxes"]
            scores = detections["scores"]
            labels = detections["labels"]

            # tqdm.write keeps the progress bar intact while logging.
            if len(scores) > 0:
                tqdm.write(f"[{filename}] Max score: {scores.max().item():.3f} for label '{TEXT_PROMPTS[labels[scores.argmax()].item()]}'")
            else:
                tqdm.write(f"[{filename}] No boxes generated at all.")

            label_lines = []
            for box, score, cls_id in zip(boxes.tolist(), scores.tolist(), labels.tolist()):
                if score < confidence_threshold:
                    continue
                yolo_box = _xyxy_to_yolo(box, img_w, img_h)
                if yolo_box is None:
                    continue
                cx, cy, w, h = yolo_box
                # cls_id is the index of the matched prompt in TEXT_PROMPTS.
                label_lines.append(f"{cls_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}")

            if label_lines:
                dest_img = os.path.join(images_out_dir, filename)
                shutil.copy2(img_path, dest_img)
                label_path = os.path.join(labels_out_dir, f"{img_name}.txt")
                with open(label_path, "w", encoding="utf-8") as f:
                    f.write("\n".join(label_lines) + "\n")
                successful_annotated += 1
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole run.
            tqdm.write(f"Error processing {filename}: {e}")

    print(f"Zero-shot auto-annotation complete! {successful_annotated} images successfully annotated.")

    # Save classes.txt so the class IDs can be mapped back to prompt names.
    with open(os.path.join(output_dataset_dir, "classes.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(TEXT_PROMPTS) + "\n")
if __name__ == "__main__":
    # This iteration favours quality over quantity: 1000 images against the
    # expanded tag list, keeping detections scoring at least 0.1.
    run_options = {"num_samples": 1000, "confidence_threshold": 0.1}
    perform_zero_shot_annotation(**run_options)