# Mini-ImageNet / src/caption/generate_captions_vit_gpt2.py
# Uploaded by ImAMJayKIM (commit c1596ac, "Upload 96 files"), 12.4 kB.
import json
import random
from pathlib import Path
from collections import defaultdict
import torch
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
# =========================================================
# 1. Settings
# =========================================================
# To caption every class:
#     INPUT_IMAGE_DIR = "/workspace/data/raw"
#
# To caption a single class only:
#     INPUT_IMAGE_DIR = "/workspace/data/raw/apple"
INPUT_IMAGE_DIR = "/workspace/data/raw/airplane"
OUTPUT_JSON_PATH = "/workspace/data/annotations/annotation.json"
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

# Number of captions kept per image.
CAPTIONS_PER_IMAGE = 3

# Per-class split fractions (validate_config() requires them to sum to 1).
SPLIT_RATIO = dict(
    train=0.7,
    val=0.15,
    test=0.15,
)
RANDOM_SEED = 42
BATCH_SIZE = 8
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp", ".bmp"]

# How INPUT_IMAGE_DIR is interpreted:
#   "auto"  - decided from the folder contents: data/raw -> every class,
#             data/raw/apple -> only the apple class
#   "raw"   - INPUT_IMAGE_DIR is the top-level raw folder (class subfolders below it)
#   "class" - INPUT_IMAGE_DIR itself is a single class folder
INPUT_MODE = "auto"

# Strip the trailing period from generated captions.
REMOVE_TRAILING_PERIOD = True

# Beam-search settings for the primary captioning pass.
GENERATION_CONFIG = dict(
    max_new_tokens=32,
    num_beams=8,
    num_return_sequences=CAPTIONS_PER_IMAGE,
    early_stopping=True,
    no_repeat_ngram_size=2,
    repetition_penalty=1.1,
    length_penalty=0.8,
)

# When beam search yields duplicate captions, top them up with sampled ones.
ENABLE_SAMPLING_FALLBACK = True
SAMPLING_FALLBACK_CONFIG = dict(
    max_new_tokens=32,
    do_sample=True,
    top_p=0.9,
    temperature=0.8,
    num_return_sequences=CAPTIONS_PER_IMAGE * 2,
    no_repeat_ngram_size=2,
    repetition_penalty=1.1,
)
MAX_FALLBACK_ROUNDS = 3

# If there are still fewer than CAPTIONS_PER_IMAGE unique captions,
# pad with duplicates so the target count is reached anyway.
FILL_WITH_DUPLICATES_IF_NEEDED = True
# =========================================================
# 2. Basic utility functions
# =========================================================
def validate_config():
    """Check the module-level settings for consistency before any work starts.

    Raises:
        ValueError: if SPLIT_RATIO lacks a "train" or "val" entry, contains a
            negative fraction, or does not sum to 1; if GENERATION_CONFIG cannot
            return exactly CAPTIONS_PER_IMAGE beams per image; or if BATCH_SIZE
            is not a positive integer.
    """
    # assign_split() reads these two keys directly; report a clear error here
    # instead of a KeyError halfway through the run.
    missing = [key for key in ("train", "val") if key not in SPLIT_RATIO]
    if missing:
        raise ValueError(f"SPLIT_RATIO is missing required key(s): {missing}")
    negative = {name: ratio for name, ratio in SPLIT_RATIO.items() if ratio < 0}
    if negative:
        raise ValueError(f"SPLIT_RATIO values must be non-negative: {negative}")

    total_ratio = sum(SPLIT_RATIO.values())
    if abs(total_ratio - 1.0) > 1e-6:
        raise ValueError(f"SPLIT_RATIO์˜ ํ•ฉ์€ 1์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ํ˜„์žฌ ํ•ฉ: {total_ratio}")
    if GENERATION_CONFIG["num_beams"] < CAPTIONS_PER_IMAGE:
        raise ValueError("num_beams๋Š” CAPTIONS_PER_IMAGE๋ณด๋‹ค ํฌ๊ฑฐ๋‚˜ ๊ฐ™์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
    if GENERATION_CONFIG["num_return_sequences"] != CAPTIONS_PER_IMAGE:
        raise ValueError("GENERATION_CONFIG์˜ num_return_sequences๋Š” CAPTIONS_PER_IMAGE์™€ ๊ฐ™์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.")

    # main() walks the records with range(0, n, BATCH_SIZE): 0 raises an obscure
    # range() error and a negative value silently yields no batches at all,
    # i.e. an empty annotation file.
    if not isinstance(BATCH_SIZE, int) or BATCH_SIZE < 1:
        raise ValueError(f"BATCH_SIZE must be a positive integer, got {BATCH_SIZE!r}")
def is_image_file(path: Path) -> bool:
    """Return True when *path*'s extension, compared case-insensitively, is in IMAGE_EXTENSIONS."""
    extension = path.suffix.lower()
    return extension in IMAGE_EXTENSIONS
def clean_caption(text: str, *, remove_trailing_period=None) -> str:
    """Normalise one generated caption.

    Every run of whitespace is collapsed to a single space and the ends are
    trimmed. When trailing-period removal is enabled, trailing periods are
    removed together with any spaces in front of them, so "a dog ." becomes
    "a dog" rather than "a dog ". A leftover trailing space would make
    otherwise identical captions look different to unique_captions().

    Args:
        text: raw decoded caption.
        remove_trailing_period: override for REMOVE_TRAILING_PERIOD; None
            (the default) uses the module-level setting.

    Returns:
        The cleaned caption, possibly an empty string.
    """
    if remove_trailing_period is None:
        remove_trailing_period = REMOVE_TRAILING_PERIOD
    # str.split() with no argument also drops leading/trailing whitespace.
    caption = " ".join(text.split())
    if remove_trailing_period:
        caption = caption.rstrip(". ")
    return caption
def unique_captions(captions):
    """Clean *captions*, then drop empty ones and case-insensitive duplicates.

    The first occurrence of each caption is kept, so the input order is preserved.
    """
    kept = []
    seen_keys = set()
    for raw in captions:
        cleaned = clean_caption(raw)
        if not cleaned:
            continue
        lowered = cleaned.lower()
        if lowered in seen_keys:
            continue
        seen_keys.add(lowered)
        kept.append(cleaned)
    return kept
def load_image(image_path: Path):
    """Open *image_path* as an RGB PIL image, or return None if it cannot be read.

    The file is opened in a ``with`` block, so its handle is released as soon as
    the RGB copy exists. Image.open is lazy and would otherwise keep the file
    open until garbage collection.

    The following are reported and skipped instead of aborting the whole run:
    unreadable files, truncated files, and files rejected by Pillow's
    decompression-bomb guard. The bomb error is not an OSError, so it is caught
    explicitly.
    """
    try:
        with Image.open(image_path) as image:
            return image.convert("RGB")
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as e:
        print(f"[SKIP] ์ด๋ฏธ์ง€๋ฅผ ์—ด ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {image_path} / error: {e}")
        return None
# =========================================================
# 3. Collect the list of images
# =========================================================
def has_direct_images(input_dir: Path) -> bool:
    """Return True if *input_dir* holds at least one image file at its top level (not recursive)."""
    return any(
        entry.is_file() and is_image_file(entry)
        for entry in input_dir.iterdir()
    )
def get_relative_base_dir(input_dir: Path) -> Path:
    """
    Choose the directory that JSON "image" values are made relative to, so they
    take the form "<class folder>/<image name>".

    Example 1)
        INPUT_IMAGE_DIR = /workspace/data/raw
        image file      = /workspace/data/raw/pizza/hf_pizza_001.jpg
        relative base   = /workspace/data/raw
        result          = pizza/hf_pizza_001.jpg
    Example 2)
        INPUT_IMAGE_DIR = /workspace/data/raw/apple
        image file      = /workspace/data/raw/apple/hf_apple_001.jpg
        relative base   = /workspace/data/raw
        result          = apple/hf_apple_001.jpg

    In "auto" mode a directory with any image directly inside it is treated
    as a class folder; otherwise it is treated as the raw root.

    Raises:
        ValueError: if INPUT_MODE is not "auto", "raw" or "class".
    """
    if INPUT_MODE not in ("auto", "raw", "class"):
        raise ValueError("INPUT_MODE์€ 'auto', 'raw', 'class' ์ค‘ ํ•˜๋‚˜์—ฌ์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
    is_class_folder = INPUT_MODE == "class" or (
        INPUT_MODE == "auto" and has_direct_images(input_dir)
    )
    return input_dir.parent if is_class_folder else input_dir
def collect_image_records(input_dir: str):
    """Gather every image file under *input_dir* (recursively) as a record dict.

    Each record holds:
        "path":  Path of the image file,
        "image": POSIX path relative to get_relative_base_dir() (e.g. "apple/x.jpg"),
        "class": first component of that relative path (e.g. "apple").
    Records come back in sorted path order.

    NOTE(review): in "raw" mode an image stored directly in INPUT_IMAGE_DIR gets
    its own file name as "class" (relative path has a single part).

    Raises:
        FileNotFoundError: if *input_dir* does not exist.
        ValueError: if no image file is found.
    """
    input_path = Path(input_dir)
    if not input_path.exists():
        raise FileNotFoundError(f"์ด๋ฏธ์ง€ ๊ฒฝ๋กœ๊ฐ€ ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค: {input_path}")

    base_dir = get_relative_base_dir(input_path)
    image_paths = [
        candidate
        for candidate in sorted(input_path.rglob("*"))
        if candidate.is_file() and is_image_file(candidate)
    ]

    records = []
    for image_path in image_paths:
        rel = image_path.relative_to(base_dir)
        # For an image value such as "apple/xxx.jpg" the class is "apple".
        records.append({
            "path": image_path,
            "image": rel.as_posix(),
            "class": rel.parts[0],
        })

    if not records:
        raise ValueError(f"์บก์…”๋‹ํ•  ์ด๋ฏธ์ง€๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค: {input_path}")
    return records
# =========================================================
# 4. train / val / test split assignment
# =========================================================
def assign_split(records, *, seed=None, split_ratio=None):
    """Add a "split" key ("train" / "val" / "test") to every record, stratified by class.

    Records are grouped by their "class" value. Each group is shuffled and cut
    into train / val / test using the "train" and "val" fractions; counts are
    truncated with int(), so any remainder lands in "test". The record dicts
    are modified in place.

    A private random.Random instance is used instead of random.seed(), so the
    global `random` state is left untouched. Random(seed).shuffle yields the
    same permutation as random.seed(seed); random.shuffle(...), so the splits
    produced are identical to the previous behaviour.

    Args:
        records: dicts with at least "class" and "image" keys.
        seed: RNG seed; None uses RANDOM_SEED.
        split_ratio: mapping with "train" and "val" fractions; None uses SPLIT_RATIO.

    Returns:
        The same record dicts, sorted by "image".
    """
    if seed is None:
        seed = RANDOM_SEED
    if split_ratio is None:
        split_ratio = SPLIT_RATIO
    rng = random.Random(seed)

    class_map = defaultdict(list)
    for record in records:
        class_map[record["class"]].append(record)

    result = []
    for items in class_map.values():
        rng.shuffle(items)
        total = len(items)
        train_end = int(total * split_ratio["train"])
        val_end = train_end + int(total * split_ratio["val"])
        for idx, item in enumerate(items):
            if idx < train_end:
                item["split"] = "train"
            elif idx < val_end:
                item["split"] = "val"
            else:
                item["split"] = "test"
            result.append(item)

    result.sort(key=lambda item: item["image"])
    return result
# =========================================================
# 5. Model loading
# =========================================================
def load_model():
    """Load the captioning model, its image processor and tokenizer.

    Returns:
        (model, processor, tokenizer, device): the model is moved to CUDA when
        available (CPU otherwise) and switched to eval mode.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"[INFO] device: {device}")
    print(f"[INFO] model: {MODEL_NAME}")

    model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Fall back to the EOS token when the tokenizer defines no pad token, and
    # keep the model config's pad id in sync with the tokenizer.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    model.to(device)
    model.eval()
    return model, processor, tokenizer, device
# =========================================================
# 6. Caption generation
# =========================================================
def decode_output_ids(output_ids, tokenizer):
captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
return [clean_caption(caption) for caption in captions]
@torch.no_grad()
def generate_by_beam_search(images, model, processor, tokenizer, device):
    """Caption a batch of images with beam search (GENERATION_CONFIG).

    The flat list of decoded captions is cut into consecutive groups of
    CAPTIONS_PER_IMAGE, giving one group per input image, in input order.
    """
    batch = processor(images=images, return_tensors="pt")
    pixel_values = batch.pixel_values.to(device)
    output_ids = model.generate(pixel_values, **GENERATION_CONFIG)
    flat = decode_output_ids(output_ids, tokenizer)

    group_size = CAPTIONS_PER_IMAGE
    return [
        flat[position * group_size:(position + 1) * group_size]
        for position in range(len(images))
    ]
@torch.no_grad()
def generate_by_sampling(image, model, processor, tokenizer, device):
    """Generate sampled captions (SAMPLING_FALLBACK_CONFIG) for a single image."""
    inputs = processor(images=[image], return_tensors="pt")
    sampled_ids = model.generate(
        inputs.pixel_values.to(device),
        **SAMPLING_FALLBACK_CONFIG
    )
    return decode_output_ids(sampled_ids, tokenizer)
def complete_caption_count(captions, original_candidates):
    """Return CAPTIONS_PER_IMAGE captions, avoiding duplicates where possible.

    The goal is CAPTIONS_PER_IMAGE distinct captions. When the model keeps
    producing near-identical sentences there may be fewer unique captions than
    that. In that case, if FILL_WITH_DUPLICATES_IF_NEEDED is True, the list is
    padded with duplicates taken from *original_candidates* until the target
    count is reached.

    Args:
        captions: caption strings; cleaned and de-duplicated here first.
        original_candidates: every raw caption generated for the image, in
            generation order; used as the pool of padding captions.

    Returns:
        At most CAPTIONS_PER_IMAGE captions. Fewer are returned only when
        padding is disabled or no candidate is non-empty after cleaning.
    """
    captions = unique_captions(captions)
    if len(captions) >= CAPTIONS_PER_IMAGE:
        return captions[:CAPTIONS_PER_IMAGE]
    if not FILL_WITH_DUPLICATES_IF_NEEDED:
        return captions

    fillers = [cleaned for cleaned in map(clean_caption, original_candidates) if cleaned]
    if not fillers:
        # Nothing usable to pad with (every candidate was empty after cleaning).
        return captions

    # Cycle through the pool in order rather than scanning it once. A single
    # pass can run out of candidates before reaching the target count, leaving
    # the image short of captions even though padding is enabled.
    index = 0
    while len(captions) < CAPTIONS_PER_IMAGE:
        captions.append(fillers[index % len(fillers)])
        index += 1
    return captions
def generate_captions_for_batch(batch_records, model, processor, tokenizer, device):
    """Caption one batch of image records.

    Images that fail to load are skipped (load_image reports them). Each
    remaining image first gets beam-search captions. If fewer than
    CAPTIONS_PER_IMAGE of them are unique and ENABLE_SAMPLING_FALLBACK is set,
    up to MAX_FALLBACK_ROUNDS rounds of sampling top them up. Finally,
    complete_caption_count() trims or pads the list.

    Returns:
        One dict per successfully loaded image, with keys
        "image", "class", "captions" and "split".
    """
    loaded = []
    for record in batch_records:
        image = load_image(record["path"])
        if image is not None:
            loaded.append((record, image))
    if not loaded:
        return []

    beam_groups = generate_by_beam_search(
        images=[image for _, image in loaded],
        model=model,
        processor=processor,
        tokenizer=tokenizer,
        device=device
    )

    results = []
    for (record, image), beam_captions in zip(loaded, beam_groups):
        candidates = list(beam_captions)
        captions = unique_captions(beam_captions)

        rounds_left = MAX_FALLBACK_ROUNDS if ENABLE_SAMPLING_FALLBACK else 0
        while len(captions) < CAPTIONS_PER_IMAGE and rounds_left > 0:
            sampled = generate_by_sampling(
                image=image,
                model=model,
                processor=processor,
                tokenizer=tokenizer,
                device=device
            )
            candidates.extend(sampled)
            captions = unique_captions(captions + sampled)
            rounds_left -= 1

        final_captions = complete_caption_count(
            captions=captions,
            original_candidates=candidates
        )
        results.append({
            "image": record["image"],
            "class": record["class"],
            "captions": final_captions,
            "split": record["split"],
        })
    return results
# =========================================================
# 7. Saving the JSON output
# =========================================================
def save_json(data, output_path: str):
    """Write *data* as indented UTF-8 JSON to *output_path*, creating parent folders.

    The JSON is first written to a sibling "<name>.tmp" file, which is then
    moved over the destination with Path.replace. An interrupted or failing
    dump therefore never leaves a truncated annotation file behind: any
    existing file stays intact until the new one is complete. The temporary
    file is removed if anything goes wrong.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=4)
        # os.replace semantics: an atomic rename on POSIX within one filesystem.
        tmp_path.replace(output_path)
    finally:
        if tmp_path.exists():
            tmp_path.unlink()
    print(f"[DONE] JSON ์ €์žฅ ์™„๋ฃŒ: {output_path}")
    print(f"[DONE] ์ด ์ด๋ฏธ์ง€ ์ˆ˜: {len(data)}")
# =========================================================
# 8. Entry point
# =========================================================
def main():
    """Run the pipeline: validate settings, collect and split images, caption them, save JSON."""
    validate_config()
    records = assign_split(collect_image_records(INPUT_IMAGE_DIR))
    print(f"[INFO] ์บก์…”๋‹ ๋Œ€์ƒ ์ด๋ฏธ์ง€ ์ˆ˜: {len(records)}")
    model, processor, tokenizer, device = load_model()

    results = []
    batch_starts = range(0, len(records), BATCH_SIZE)
    for start in tqdm(batch_starts, desc="captioning"):
        batch = records[start:start + BATCH_SIZE]
        results.extend(
            generate_captions_for_batch(
                batch_records=batch,
                model=model,
                processor=processor,
                tokenizer=tokenizer,
                device=device
            )
        )
    save_json(results, OUTPUT_JSON_PATH)


if __name__ == "__main__":
    main()