Spaces:
Sleeping
Sleeping
| import os # ํ์ผ/ํด๋ ํ์ | |
| import json # JSON ์ ์ฅ | |
| import random # ๋ฐ์ดํฐ ์๊ธฐ | |
| import torch # GPU ์ฌ์ฉ | |
| import re # ์ ๊ท์ (๋ฌธ์ฅ ํํฐ๋ง) | |
| from collections import defaultdict # ํด๋์ค๋ณ ๊ทธ๋ฃนํ | |
| from PIL import Image # ์ด๋ฏธ์ง ๋ก๋ | |
| from transformers import BlipProcessor, BlipForConditionalGeneration # BLIP | |
| from sentence_transformers import SentenceTransformer, util # SBERT | |
# ----------------------
# 1. Configuration
# ----------------------
ROOT_DIR = "data/raw"  # image root; expected layout is raw/<class>/<image>
OUTPUT_JSON = "annotation.json"  # output annotation file name
TARGET_CAPTIONS = 3  # captions to keep per image (3 or 5 recommended)
# Captions whose SBERT cosine similarity exceeds this are treated as duplicates.
# A higher value rejects fewer captions (more lenient de-duplication).
SIM_THRESHOLD = 0.85
MIN_WORDS = 3  # captions with fewer words are discarded
MAX_ATTEMPTS = 10  # maximum caption-generation rounds per image
TRAIN_RATIO = 0.7  # fraction of each class assigned to train
VAL_RATIO = 0.15  # fraction of each class assigned to val
# Fraction intended for test. The split step assigns test whatever remains
# after train and val, so this value is only honoured if the three sum to 1.
TEST_RATIO = 0.15

# Fail fast on an inconsistent split configuration instead of silently
# producing a test split of a different size than TEST_RATIO suggests.
_ratio_sum = TRAIN_RATIO + VAL_RATIO + TEST_RATIO
if abs(_ratio_sum - 1.0) > 1e-6:
    raise ValueError(
        f"TRAIN_RATIO + VAL_RATIO + TEST_RATIO must sum to 1, got {_ratio_sum}"
    )

device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when available
print("device : ", device)
# ----------------------
# 2. Model loading
# ----------------------
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"

# BLIP processor: converts a PIL image into model inputs and decodes
# generated token ids back into text.
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)

# BLIP captioning model, placed on the selected device.
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT).to(device)

# Sentence-BERT encoder: sentence -> embedding vector, used for caption similarity.
embedder = SentenceTransformer("all-MiniLM-L6-v2", device=device)
# ----------------------
# 3. Caption generation
# ----------------------
def generate_captions(image, n):
    """Sample ``n`` candidate captions for ``image`` with the BLIP model.

    Args:
        image: Image handed to the BLIP processor (a PIL RGB image in this script).
        n: Number of sequences to sample (``num_return_sequences``).

    Returns:
        list[str]: Decoded captions, stripped of surrounding whitespace and
        lower-cased.
    """
    # Preprocess the image into tensors on the target device.
    model_inputs = processor(images=image, return_tensors="pt").to(device)

    # Top-k / nucleus sampling rather than greedy decoding, so the n
    # returned sequences can differ from one another.
    sampling_config = {
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.9,
        "num_return_sequences": n,
        "max_length": 30,
    }
    token_ids = blip_model.generate(**model_inputs, **sampling_config)

    # Token ids -> normalised strings.
    captions = []
    for sequence in token_ids:
        text = processor.decode(sequence, skip_special_tokens=True)
        captions.append(text.strip().lower())
    return captions
# ----------------------
# 4. Basic quality filter
# ----------------------
def basic_filter(captions):
    """Return the captions that pass cheap, text-only quality checks.

    A caption is dropped when it:
      * has fewer than ``MIN_WORDS`` whitespace-separated words,
      * has fewer than 60% distinct words (degenerate repetition), or
      * contains any character other than lowercase ASCII letters, digits
        and whitespace.

    Args:
        captions: Iterable of caption strings.

    Returns:
        list[str]: Surviving captions in their original order.
    """
    def _passes(caption):
        tokens = caption.split()
        if len(tokens) < MIN_WORDS:  # too short
            return False
        if len(set(tokens)) < len(tokens) * 0.6:  # too many repeated words
            return False
        # Reject punctuation / non-ASCII characters.
        return re.search(r"[^a-z0-9\s]", caption) is None

    return [caption for caption in captions if _passes(caption)]
# ----------------------
# 5. Keyword extraction
# ----------------------
# Function words ignored by the keyword-overlap check. Built once at import
# time instead of on every call.
_STOPWORDS = frozenset(
    {"a", "the", "on", "in", "at", "with", "and", "of", "to", "is", "are"}
)


def extract_keywords(caption):
    """Return the set of content words in ``caption``.

    The caption is split on whitespace, and every token found in
    ``_STOPWORDS`` is removed. Matching is exact, with no case folding.

    Args:
        caption: Caption string.

    Returns:
        set[str]: The remaining distinct tokens.
    """
    return {word for word in caption.split() if word not in _STOPWORDS}
# ----------------------
# 6. Similarity + keyword filter
# ----------------------
def advanced_filter(captions, *, max_keyword_overlap=0.7):
    """Greedily drop captions that duplicate a caption kept before them.

    Captions are visited in input order. A caption is kept only if, compared
    with every caption already kept:

    * the SBERT cosine similarity of the pair is <= ``SIM_THRESHOLD``, and
    * the fraction of *its own* keywords (see ``extract_keywords``) that also
      appear in the kept caption is <= ``max_keyword_overlap``.

    Args:
        captions: Candidate caption strings. Earlier entries take priority.
        max_keyword_overlap: Keyword-overlap ratio above which a caption is
            treated as a lexical duplicate. Defaults to 0.7, the value that
            was previously hard-coded.

    Returns:
        list[str]: The kept captions, in their original order.
    """
    if not captions:
        return []

    embeddings = embedder.encode(captions, convert_to_tensor=True)
    # Compute one pairwise similarity matrix with a single device->host
    # transfer, instead of a cos_sim call plus a blocking .item() per pair.
    similarity = util.cos_sim(embeddings, embeddings).tolist()
    # Keyword sets depend only on the caption text, so compute each one once.
    keywords = [extract_keywords(cap) for cap in captions]

    kept_idx = []
    for i, kw_i in enumerate(keywords):
        is_duplicate = False
        for j in kept_idx:
            if similarity[i][j] > SIM_THRESHOLD:  # too close in meaning
                is_duplicate = True
                break
            # Normalised by the candidate's own keyword count, so the ratio
            # is asymmetric. max(..., 1) avoids dividing by zero when every
            # word was a stopword.
            overlap = len(kw_i & keywords[j]) / max(len(kw_i), 1)
            if overlap > max_keyword_overlap:  # too many shared keywords
                is_duplicate = True
                break
        if not is_duplicate:
            kept_idx.append(i)
    return [captions[i] for i in kept_idx]
# ----------------------
# 7. Caption generation loop
# ----------------------
def get_captions(image):
    """Collect up to ``TARGET_CAPTIONS`` filtered, mutually distinct captions.

    Each round samples three times as many candidates as are still missing
    and passes them through ``basic_filter``. The survivors are merged with
    the captions kept so far, and ``advanced_filter`` is re-applied. The loop
    stops once ``TARGET_CAPTIONS`` captions are kept or after
    ``MAX_ATTEMPTS`` rounds. The result may therefore be shorter than
    ``TARGET_CAPTIONS``, or even empty.

    Args:
        image: Image passed through to ``generate_captions``.

    Returns:
        list[str]: At most ``TARGET_CAPTIONS`` captions.
    """
    final_caps = []
    for _ in range(MAX_ATTEMPTS):
        if len(final_caps) >= TARGET_CAPTIONS:
            break
        needed = TARGET_CAPTIONS - len(final_caps)
        # Oversample, because many candidates are rejected by the filters.
        new_caps = basic_filter(generate_captions(image, needed * 3))
        # Order-preserving de-duplication. Previously accepted captions are
        # checked first by advanced_filter, so newcomers compete only for the
        # free slots. list(set(...)) made the order depend on per-process
        # string hashing, which let new captions displace accepted ones and
        # made runs non-reproducible.
        combined = list(dict.fromkeys(final_caps + new_caps))
        final_caps = advanced_filter(combined)[:TARGET_CAPTIONS]
    return final_caps
# ----------------------
# 8. Data collection
# ----------------------
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")  # matched case-insensitively

dataset = []
# Sorted traversal so the output order does not depend on the filesystem.
for class_name in sorted(os.listdir(ROOT_DIR)):  # one sub-directory per class
    class_path = os.path.join(ROOT_DIR, class_name)
    if not os.path.isdir(class_path):
        continue
    for filename in sorted(os.listdir(class_path)):
        if not filename.lower().endswith(IMAGE_EXTENSIONS):
            continue
        path = os.path.join(class_path, filename)
        relative_path = f"{class_name}/{filename}"
        # The context manager closes the file handle deterministically.
        # convert() returns a new, fully loaded image that remains usable
        # after the file is closed.
        try:
            with Image.open(path) as img:
                image = img.convert("RGB")
        except OSError as err:
            # Missing, corrupt or unrecognised file: skip it rather than
            # aborting the whole (slow, GPU-bound) run.
            print(f"\n[skip] {relative_path}: {err}")
            continue

        captions = get_captions(image)
        dataset.append({
            "image": relative_path,  # path relative to ROOT_DIR
            "class": class_name,  # class label = folder name
            "captions": captions,  # list of caption strings
        })

        print(f"\n{relative_path}")
        for i, c in enumerate(captions):
            print(f"{i+1}. {c}")
        if len(captions) < TARGET_CAPTIONS:
            print(f"[warn] only {len(captions)}/{TARGET_CAPTIONS} captions kept")
# ----------------------
# 9. Stratified split
# ----------------------
# Fixed seed so re-running the script reproduces the same train/val/test split.
SPLIT_SEED = 42

class_groups = defaultdict(list)
for item in dataset:
    class_groups[item["class"]].append(item)  # group samples by class label

# Dedicated RNG: the split does not depend on, or disturb, the global
# `random` state.
split_rng = random.Random(SPLIT_SEED)
train_set, val_set, test_set = [], [], []
for class_name in sorted(class_groups):
    # Sort before shuffling so the split is independent of collection order.
    items = sorted(class_groups[class_name], key=lambda entry: entry["image"])
    split_rng.shuffle(items)
    total = len(items)
    train_end = max(1, int(total * TRAIN_RATIO))  # at least one train sample
    # At least one val sample once the class has >= 2 items. Test receives
    # the remainder, so very small classes can end up with an empty test split.
    val_end = train_end + max(1, int(total * VAL_RATIO))
    train_set += items[:train_end]
    val_set += items[train_end:val_end]
    test_set += items[val_end:]

# Attach the split label to every sample.
for split_name, members in (("train", train_set), ("val", val_set), ("test", test_set)):
    for item in members:
        item["split"] = split_name

dataset = train_set + val_set + test_set  # merge back into one list
# ----------------------
# 10. Save JSON
# ----------------------
with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
    # ensure_ascii=False writes non-ASCII text as-is (the file is UTF-8).
    json.dump(dataset, f, indent=4, ensure_ascii=False)
# "완료: ... 생성됨" = "Done: ... created" (literal restored from a mis-encoded copy).
print(f"\n완료: {OUTPUT_JSON} 생성됨")