Mini-ImageNet / src /caption /generate_captions_blip.py
ImAMJayKIM's picture
Upload 96 files
c1596ac verified
Raw
History Blame Contribute Delete
7.06 kB
import os # file/folder traversal
import json # JSON output
import random # shuffling the data
import torch # GPU detection
import re # regular expressions (caption filtering)
from collections import defaultdict # grouping items by class
from PIL import Image # image loading
from transformers import BlipProcessor, BlipForConditionalGeneration # BLIP
from sentence_transformers import SentenceTransformer, util # SBERT
# ----------------------
# 1. Configuration
# ----------------------
ROOT_DIR = "data/raw" # image root folder, laid out as raw/<class>/<image>
OUTPUT_JSON = "annotation.json" # name of the resulting JSON file
TARGET_CAPTIONS = 3 # number of captions to keep per image (3 or 5 recommended)
SIM_THRESHOLD = 0.85 # cosine-similarity cutoff: a caption scoring above this against an already-kept caption is dropped
MIN_WORDS = 3 # minimum word count (shorter captions are dropped)
MAX_ATTEMPTS = 10 # maximum number of caption-generation rounds per image
TRAIN_RATIO = 0.7 # fraction of each class assigned to train
VAL_RATIO = 0.15 # fraction of each class assigned to val
TEST_RATIO = 0.15 # nominal test fraction; not referenced in the visible code (test receives whatever is left)
device = "cuda" if torch.cuda.is_available() else "cpu" # use the GPU when one is available
print("device : ", device)
# ----------------------
# 2. Model loading
# ----------------------
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# BLIP processor: converts an image into model input tensors and decodes generated token ids back to text
blip_model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base"
).to(device)
# BLIP caption-generation model, moved to the selected device
embedder = SentenceTransformer("all-MiniLM-L6-v2", device=device)
# SBERT sentence encoder (sentence -> vector), used for the caption-similarity check
# ----------------------
# 3. Caption generation function
# ----------------------
def generate_captions(image, n):
    """Sample ``n`` candidate captions for a single image with BLIP.

    Args:
        image: The image to caption, passed straight to the BLIP processor
            (the visible caller supplies an RGB PIL image).
        n: Number of sequences to sample in one ``generate`` call.

    Returns:
        A list of decoded captions, stripped of surrounding whitespace and
        lower-cased. Sampling is stochastic, so duplicates are possible.
    """
    # Stochastic decoding (top-k + nucleus sampling) is what gives the
    # candidates their diversity.
    sampling_options = {
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 0.9,
        "num_return_sequences": n,
        "max_length": 30,
    }

    encoded = processor(images=image, return_tensors="pt").to(device)
    token_ids = blip_model.generate(**encoded, **sampling_options)

    # Token ids -> normalized text.
    captions = []
    for sequence in token_ids:
        text = processor.decode(sequence, skip_special_tokens=True)
        captions.append(text.strip().lower())
    return captions
# ----------------------
# 4. Basic quality filter
# ----------------------
# Matches any character other than a lowercase ASCII letter, a digit or
# whitespace. Compiled once instead of being looked up for every caption.
_DISALLOWED_CHARS = re.compile(r"[^a-z0-9\s]")


def basic_filter(captions, min_words=None, min_unique_ratio=0.6):
    """Drop captions that fail cheap, text-only quality checks.

    A caption is rejected when it
      * has fewer than ``min_words`` whitespace-separated words,
      * is too repetitive (distinct words < ``min_unique_ratio`` * word count), or
      * contains any character other than ``a-z``, ``0-9`` or whitespace
        (punctuation and uppercase letters therefore disqualify a caption).

    Args:
        captions: Iterable of caption strings.
        min_words: Minimum number of words required. Defaults to the
            module-level ``MIN_WORDS`` when left as ``None``.
        min_unique_ratio: Minimum share of distinct words; defaults to the
            previously hard-coded ``0.6``.

    Returns:
        A list of the captions that passed every check, in input order.
    """
    if min_words is None:
        min_words = MIN_WORDS

    kept = []
    for caption in captions:
        words = caption.split()
        if len(words) < min_words:
            continue  # too short
        if len(set(words)) < len(words) * min_unique_ratio:
            continue  # mostly repeated words
        if _DISALLOWED_CHARS.search(caption):
            continue  # unexpected characters
        kept.append(caption)
    return kept
# ----------------------
# 5. Keyword extraction
# ----------------------
# Function words ignored when comparing captions by keyword. Built once at
# import time rather than on every call.
_CAPTION_STOPWORDS = frozenset(
    {"a", "the", "on", "in", "at", "with", "and", "of", "to", "is", "are"}
)


def extract_keywords(caption):
    """Return the set of non-stopword tokens in ``caption``.

    Args:
        caption: Caption text; it is split on whitespace and compared
            verbatim (no case folding or punctuation stripping happens here).

    Returns:
        A ``set`` of the distinct words that are not stopwords. Empty when
        the caption is empty or consists only of stopwords.
    """
    return {word for word in caption.split() if word not in _CAPTION_STOPWORDS}
# ----------------------
# 6. Similarity + keyword filter
# ----------------------
def advanced_filter(captions):
    """Greedily keep captions that are not near-duplicates of earlier keepers.

    Captions are visited in input order. A caption is discarded when, against
    any caption already kept, either
      * their embedding cosine similarity exceeds ``SIM_THRESHOLD``, or
      * more than 70% of this caption's keywords also occur in the kept one
        (the ratio is taken relative to the candidate's own keyword count).

    Because selection is first-come-first-kept, the result depends on the
    order of ``captions``.

    Args:
        captions: List of caption strings.

    Returns:
        The kept captions, in their original relative order. An empty input
        yields an empty list without touching the embedding model.
    """
    if not captions:
        return []

    embeddings = embedder.encode(captions, convert_to_tensor=True)
    # One pairwise matrix (entry [i][j] = cos_sim of caption i and j) pulled
    # to plain Python floats once, instead of a cos_sim call and a
    # device-to-host .item() transfer for every pair.
    similarity = util.cos_sim(embeddings, embeddings).tolist()
    # Keyword sets are loop-invariant, so extract each exactly once.
    keywords = [extract_keywords(cap) for cap in captions]

    kept_idx = []
    for i, kw_i in enumerate(keywords):
        denom = max(len(kw_i), 1)  # guards captions made only of stopwords
        is_redundant = any(
            similarity[i][j] > SIM_THRESHOLD
            or len(kw_i & keywords[j]) / denom > 0.7
            for j in kept_idx
        )
        if not is_redundant:
            kept_idx.append(i)

    return [captions[i] for i in kept_idx]
# ----------------------
# 7. Caption generation loop
# ----------------------
def get_captions(image):
    """Collect up to ``TARGET_CAPTIONS`` distinct captions for one image.

    Each round samples three times as many candidates as are still missing,
    runs them through ``basic_filter``, merges them with the captions accepted
    so far and re-applies ``advanced_filter``. The loop stops once enough
    captions survive or after ``MAX_ATTEMPTS`` rounds.

    Args:
        image: The image to caption, forwarded to ``generate_captions``.

    Returns:
        A list of at most ``TARGET_CAPTIONS`` captions; it may be shorter
        (even empty) if the attempt budget runs out first.
    """
    final_caps = []
    for _ in range(MAX_ATTEMPTS):
        if len(final_caps) >= TARGET_CAPTIONS:
            break

        needed = TARGET_CAPTIONS - len(final_caps)
        # Over-generate so that enough candidates survive filtering.
        candidates = basic_filter(generate_captions(image, needed * 3))

        # Remove exact duplicates while preserving order. advanced_filter is
        # greedy and order-dependent, so going through set() (whose string
        # ordering varies between runs) made results non-reproducible and
        # could displace captions accepted in an earlier round. Keeping
        # previously accepted captions first lets them win ties.
        combined = list(dict.fromkeys(final_caps + candidates))

        final_caps = advanced_filter(combined)[:TARGET_CAPTIONS]
    return final_caps
# ----------------------
# 8. Data collection
# ----------------------
# Walk ROOT_DIR/<class>/<image>, caption every image and accumulate one record
# per image in `dataset`.
dataset = []
# sorted(): os.listdir returns entries in arbitrary order; sorting keeps the
# record order stable from run to run.
for class_name in sorted(os.listdir(ROOT_DIR)):
    class_path = os.path.join(ROOT_DIR, class_name)
    if not os.path.isdir(class_path):
        continue  # stray files next to the class folders

    for filename in sorted(os.listdir(class_path)):
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue  # not an image we handle

        path = os.path.join(class_path, filename)
        try:
            # The context manager releases the file handle; convert() returns
            # an independent in-memory copy that outlives the `with` block.
            with Image.open(path) as img:
                image = img.convert("RGB")
        except OSError as err:
            # An unreadable or corrupt file should not abort the whole run and
            # throw away every caption generated so far.
            print(f"\n[skip] {class_name}/{filename}: {err}")
            continue

        captions = get_captions(image)
        dataset.append({
            "image": f"{class_name}/{filename}",  # path relative to ROOT_DIR
            "class": class_name,  # class label (folder name)
            "captions": captions,  # list of caption strings
        })

        # Progress output: the image followed by its numbered captions.
        print(f"\n{class_name}/{filename}")
        for i, c in enumerate(captions, start=1):
            print(f"{i}. {c}")
# ----------------------
# 9. Stratified Split
# ----------------------
# Split every class separately so each split keeps the class balance.
# A dedicated, seeded generator makes the assignment reproducible: with the
# same input order, an image lands in the same split on every run.
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)

class_groups = defaultdict(list)
for item in dataset:
    class_groups[item["class"]].append(item)  # group records by class

train_set, val_set, test_set = [], [], []
for items in class_groups.values():
    split_rng.shuffle(items)  # shuffle within the class
    total = len(items)
    # max(1, ...) gives train and val at least one slot each; test receives
    # whatever is left, so for very small classes val and/or test may end up
    # empty (the slices simply run past the end of the list).
    train_end = max(1, int(total * TRAIN_RATIO))
    val_end = train_end + max(1, int(total * VAL_RATIO))
    train_set += items[:train_end]
    val_set += items[train_end:val_end]
    test_set += items[val_end:]

# Tag every record with the split it was assigned to.
for split_name, members in (("train", train_set), ("val", val_set), ("test", test_set)):
    for item in members:
        item["split"] = split_name

dataset = train_set + val_set + test_set  # merge back into a single list
# ----------------------
# 10. Save JSON
# ----------------------
# Write the split-labelled dataset as indented UTF-8 JSON; ensure_ascii=False
# keeps any non-ASCII text readable instead of \u-escaping it.
with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
    json.dump(dataset, f, indent=4, ensure_ascii=False)
# Completion message ("Done: <file> created"), restored from mis-decoded bytes.
print(f"\n완료: {OUTPUT_JSON} 생성됨")