Mini-ImageNet / src /caption /generate_captions_git.py
ImAMJayKIM's picture
Upload 96 files
c1596ac verified
Raw
History Blame Contribute Delete
16.1 kB
import json
import math
import random
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple
import os
import torch
from dotenv import load_dotenv
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor
# ============================================================
# 1. Configuration
# ============================================================
# Load variables from a .env file into the process environment.
load_dotenv()
# Read HF_TOKEN (set in .env or the environment); None when it is not defined.
hf_token = os.getenv("HF_TOKEN")
# Root path of the source images.
# Examples:
# - caption every class: "data/raw"
# - caption a single class only: "data/raw/apple"
INPUT_IMAGE_PATH = "data/raw/airplane"
# Root used as the base when building the "image" field.
# The JSON stores paths relative to it, e.g. "pizza/hf_pizza_001.jpg".
DATA_RAW_ROOT = "data/raw"
# Where the caption results are written.
OUTPUT_JSON_PATH = "data/annotations/captions_git.json"
# Where the list of images that failed is written.
ERROR_JSON_PATH = "data/annotations/caption_git_errors.json"
# GIT model checkpoint.
# Default recommendation: microsoft/git-base-coco
# For a larger model: microsoft/git-large-coco
MODEL_NAME = "microsoft/git-large-coco"
# Number of captions to produce per image.
CAPTIONS_PER_IMAGE = 3
# Split ratios.
# Default is 7 : 1.5 : 1.5
SPLIT_RATIO = {
    "train": 0.7,
    "val": 0.15,
    "test": 0.15,
}
# Seed that makes the split assignment reproducible.
RANDOM_SEED = 42
# Inference batch size.
# If GPU memory runs out, reduce it in the order 8 -> 4 -> 2 -> 1.
BATCH_SIZE = 8
# Device selection.
# "auto": use the GPU when CUDA is available, otherwise the CPU.
# Can also be set explicitly: "cuda", "cpu"
DEVICE = "auto"
# dtype selection.
# "auto": float16 on CUDA, float32 on CPU.
# Can also be set explicitly: "float32", "float16", "bfloat16"
TORCH_DTYPE = "auto"
# Interval (in images) between intermediate saves.
# Keeps partial results when a long run fails part-way through.
SAVE_EVERY_N_IMAGES = 100
# When OUTPUT_JSON_PATH already exists, whether to skip images that are
# already captioned in it.
RESUME_FROM_EXISTING_JSON = True
# Supported image file extensions (compared against the lower-cased suffix).
SUPPORTED_EXTENSIONS = {
    ".jpg", ".jpeg", ".png", ".webp", ".bmp"
}
# Caption generation settings.
# num_beams must be >= num_return_sequences.
# NOTE: no num_beam_groups is set here, so this is plain beam search rather
# than diverse (grouped) beam search; duplicate beams are removed afterwards
# and topped up by the sampling fallback configured below.
GENERATION_CONFIG = {
    "max_length": 40,
    "num_beams": 5,
    "num_return_sequences": CAPTIONS_PER_IMAGE,
    "early_stopping": True,
    "no_repeat_ngram_size": 2,
}
# Whether to top up with extra sampled captions when the beam search results
# contain too many duplicates.
ENABLE_SAMPLING_FALLBACK = True
SAMPLING_FALLBACK_CONFIG = {
    "max_length": 40,
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.8,
    "num_return_sequences": CAPTIONS_PER_IMAGE,
    "no_repeat_ngram_size": 2,
}
# ============================================================
# 2. Data structures
# ============================================================
@dataclass
class ImageItem:
    """A single image file queued for captioning, with its dataset metadata."""

    # Location of the image file on disk.
    path: Path
    # Path string identifying the image; stored as the "image" value of a result.
    image_field: str
    # Name of the class the image belongs to.
    class_name: str
    # Dataset split label; stays empty until a split has been assigned.
    split: str = ""
# ============================================================
# 3. Utility functions
# ============================================================
def resolve_device() -> torch.device:
    """Turn the DEVICE setting into a concrete ``torch.device``.

    An explicit DEVICE value is passed straight to ``torch.device``; the
    special value "auto" selects CUDA when it is available and the CPU
    otherwise.
    """
    if DEVICE != "auto":
        return torch.device(DEVICE)
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
def resolve_dtype(device: torch.device) -> torch.dtype:
    """Turn the TORCH_DTYPE setting into a concrete ``torch.dtype``.

    Args:
        device: Device the model will run on.

    Returns:
        For "auto": float16 on CUDA and float32 on any other device.
        Otherwise the dtype named by TORCH_DTYPE, except that float16 and
        bfloat16 are replaced by float32 (with a warning) on the CPU.

    Raises:
        ValueError: If TORCH_DTYPE is neither "auto" nor a supported name.
    """
    if TORCH_DTYPE == "auto":
        if device.type == "cuda":
            return torch.float16
        return torch.float32

    named_dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    requested = named_dtypes.get(TORCH_DTYPE)
    if requested is None:
        raise ValueError(f"지원하지 않는 TORCH_DTYPE입니다: {TORCH_DTYPE}")

    # Reduced-precision dtypes are not used on the CPU; fall back to float32.
    is_reduced_precision = requested in (torch.float16, torch.bfloat16)
    if is_reduced_precision and device.type == "cpu":
        print("[WARN] CPU에서는 float16/bfloat16이 불안정할 수 있어 float32로 변경합니다.")
        return torch.float32
    return requested
def normalize_caption(text: str) -> str:
    """Trim *text* and collapse every run of whitespace into one space."""
    trimmed = text.strip()
    return re.sub(r"\s+", " ", trimmed)
def deduplicate_captions(captions: List[str]) -> List[str]:
    """Normalize captions and drop empty ones and case-insensitive repeats.

    Args:
        captions: Raw caption strings.

    Returns:
        The normalized captions in their original order; when two captions
        differ only by letter case, the first occurrence is kept.
    """
    # A dict keeps insertion order, and setdefault only stores the first
    # caption seen for each lower-cased key.
    first_by_key: Dict[str, str] = {}
    for raw in captions:
        cleaned = normalize_caption(raw)
        if cleaned:
            first_by_key.setdefault(cleaned.lower(), cleaned)
    return list(first_by_key.values())
def ensure_caption_count(captions: List[str], target_count: int) -> List[str]:
    """Return a caption list trimmed or padded to *target_count* entries.

    Args:
        captions: Candidate captions; they are deduplicated first.
        target_count: Number of captions wanted.

    Returns:
        The first ``target_count`` unique captions when there are enough.
        When there are too few, the last unique caption is repeated to fill
        the gap; when there are none at all, empty strings are returned.
    """
    unique = deduplicate_captions(captions)
    if not unique:
        return [""] * target_count

    missing = target_count - len(unique)
    if missing <= 0:
        return unique[:target_count]
    # Pad by repeating the final unique caption.
    return unique + [unique[-1]] * missing
def save_json(path: Path, data: List[dict]) -> None:
    """Write *data* to *path* as pretty-printed UTF-8 JSON, atomically.

    The JSON is first written to a sibling temporary file and then moved
    over *path* with ``os.replace``. An interrupted or failed write therefore
    never leaves a truncated file behind, which matters because the output
    file is re-read when a run is resumed.

    Args:
        path: Destination file; missing parent directories are created.
        data: JSON-serializable list of result/error records.

    Raises:
        TypeError: If *data* contains a value that JSON cannot serialize
            (the previous content of *path*, if any, is left untouched).
        OSError: If the file cannot be written or replaced.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file no longer exists; this only
        # cleans up a partial file left by a failed write.
        if tmp_path.exists():
            tmp_path.unlink()
def load_existing_json(path: Path) -> Dict[str, dict]:
    """Load previously saved results, indexed by their "image" value.

    Args:
        path: JSON file holding a list of result records.

    Returns:
        A mapping from each record's "image" value to the record. Returns an
        empty mapping when *path* does not exist. Entries that are not
        objects, or that have no non-empty "image" value, are skipped; when
        an image appears more than once the last record wins.

    Raises:
        ValueError: If the file is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``) or its top level is not a list.
    """
    if not path.exists():
        return {}
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        # Iterating a dict/str here would fail later with a confusing
        # AttributeError, so reject the wrong shape up front.
        raise ValueError(
            f"Expected a JSON list of result items in {path}, "
            f"got {type(data).__name__}"
        )
    return {
        item["image"]: item
        for item in data
        if isinstance(item, dict) and item.get("image")
    }
# ============================================================
# 4. Image collection
# ============================================================
def collect_images(input_path: Path, data_raw_root: Path) -> List[ImageItem]:
    """Find every supported image under *input_path* and describe it.

    Args:
        input_path: Directory that is searched recursively for image files.
        data_raw_root: Root that each image path is made relative to. The
            first component of the relative path becomes the class name and
            the whole relative path (POSIX form) becomes the image field.

    Returns:
        One ``ImageItem`` per image, ordered by sorted file path.

    Raises:
        FileNotFoundError: If *input_path* does not exist.
        RuntimeError: If no supported image file is found.
        ValueError: If an image is not located under *data_raw_root*, or is
            not inside a class sub-folder of it.
    """
    if not input_path.exists():
        raise FileNotFoundError(f"입력 경로가 존재하지 않습니다: {input_path}")

    found = [
        candidate
        for candidate in input_path.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in SUPPORTED_EXTENSIONS
    ]
    found.sort()
    if not found:
        raise RuntimeError(f"이미지를 찾지 못했습니다: {input_path}")

    collected = []
    for image_path in found:
        try:
            relative_path = image_path.relative_to(data_raw_root)
        except ValueError:
            raise ValueError(
                f"이미지 경로가 DATA_RAW_ROOT 하위에 있어야 합니다.\n"
                f"image_path={image_path}\n"
                f"DATA_RAW_ROOT={data_raw_root}"
            )

        parts = relative_path.parts
        # Fewer than two parts means the file sits directly in the root,
        # so there is no class folder to take the class name from.
        if len(parts) < 2:
            raise ValueError(
                f"이미지는 클래스 폴더 하위에 있어야 합니다: {image_path}\n"
                f"예: data/raw/pizza/hf_pizza_001.jpg"
            )

        collected.append(
            ImageItem(
                path=image_path,
                image_field=relative_path.as_posix(),
                class_name=parts[0],
            )
        )
    return collected
# ============================================================
# 5. Split assignment
# ============================================================
def calculate_split_counts(total_count: int) -> Dict[str, int]:
    """Divide *total_count* items among the splits named in SPLIT_RATIO.

    Each split first receives the floor of its proportional share; the items
    left over are then handed out one at a time to the splits with the
    largest fractional remainder.

    Args:
        total_count: Number of items to distribute.

    Returns:
        A mapping from split name to item count.
    """
    ratio_total = sum(SPLIT_RATIO.values())

    exact_shares = {}
    counts = {}
    for name, ratio in SPLIT_RATIO.items():
        exact_shares[name] = total_count * ratio / ratio_total
        counts[name] = int(math.floor(exact_shares[name]))

    leftover = total_count - sum(counts.values())

    # Hand the leftover items to the splits with the largest fractional part.
    by_fraction = sorted(
        exact_shares,
        key=lambda name: exact_shares[name] - counts[name],
        reverse=True,
    )
    for name in by_fraction[:leftover]:
        counts[name] += 1
    return counts
def assign_splits(items: List[ImageItem]) -> List[ImageItem]:
    """Assign a dataset split to every item, stratified by class.

    Items are grouped by class; each group is shuffled with a generator
    seeded by RANDOM_SEED and then cut into consecutive slices whose sizes
    come from ``calculate_split_counts``. The split names are taken from the
    computed counts instead of a hard-coded list, so renaming or adding a
    split in SPLIT_RATIO cannot leave items with an empty split.

    Args:
        items: Items to label; their ``split`` attribute is set in place.

    Returns:
        The same list that was passed in.
    """
    rng = random.Random(RANDOM_SEED)
    class_map = defaultdict(list)
    for item in items:
        class_map[item.class_name].append(item)

    for class_items in class_map.values():
        rng.shuffle(class_items)
        counts = calculate_split_counts(len(class_items))
        start = 0
        # Walk the splits in the order the counts were produced and give each
        # one the next `count` items of the shuffled class.
        for split_name, count in counts.items():
            end = start + count
            for item in class_items[start:end]:
                item.split = split_name
            start = end
    return items
# ============================================================
# 6. Model loading
# ============================================================
def load_model():
    """Load the captioning model and its processor.

    The device and dtype are resolved from the module settings and printed
    together with the model name. The model is moved to the device and put
    into evaluation mode.

    Returns:
        A ``(model, processor, device, torch_dtype)`` tuple.
    """
    device = resolve_device()
    torch_dtype = resolve_dtype(device)

    for label, value in (
        ("device", device),
        ("dtype", torch_dtype),
        ("model", MODEL_NAME),
    ):
        print(f"[INFO] {label}={value}")

    processor = AutoProcessor.from_pretrained(MODEL_NAME, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch_dtype,
        token=hf_token,
    )
    model.to(device)
    model.eval()
    return model, processor, device, torch_dtype
# ============================================================
# 7. Captioning
# ============================================================
def load_batch_images(batch_items: List[ImageItem]) -> Tuple[List[Image.Image], List[ImageItem], List[dict]]:
    """Open the images of one batch, separating failures from successes.

    Args:
        batch_items: Items whose image files should be loaded.

    Returns:
        A tuple ``(images, valid_items, errors)``: the RGB-converted images,
        the items they belong to (same order), and one error record
        (image, class, split, error message) per item that failed to load.
    """
    loaded = []
    usable_items = []
    failures = []
    for item in batch_items:
        try:
            with Image.open(item.path) as handle:
                rgb_image = handle.convert("RGB")
        except Exception as exc:
            # An unreadable file must not abort the batch; record it instead.
            failures.append({
                "image": item.image_field,
                "class": item.class_name,
                "split": item.split,
                "error": str(exc),
            })
        else:
            loaded.append(rgb_image)
            usable_items.append(item)
    return loaded, usable_items, failures
@torch.inference_mode()
def generate_batch_captions(
    model,
    processor,
    device: torch.device,
    torch_dtype: torch.dtype,
    images: List[Image.Image],
) -> List[List[str]]:
    """Generate captions for a batch of images using GENERATION_CONFIG.

    Args:
        model: Captioning model exposing ``generate``.
        processor: Processor used to encode the images and decode the output.
        device: Device the encoded inputs are moved to.
        torch_dtype: dtype the ``pixel_values`` input is cast to.
        images: Images to caption.

    Returns:
        One list of deduplicated captions per image, in input order. A list
        may hold fewer than CAPTIONS_PER_IMAGE captions after deduplication.
    """
    encoded = processor(images=images, return_tensors="pt")
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    if "pixel_values" in model_inputs:
        model_inputs["pixel_values"] = model_inputs["pixel_values"].to(dtype=torch_dtype)

    output_ids = model.generate(**model_inputs, **GENERATION_CONFIG)
    texts = processor.batch_decode(output_ids, skip_special_tokens=True)

    # The decoded texts are cut into consecutive groups of CAPTIONS_PER_IMAGE.
    # NOTE(review): this assumes generate() returns the sequences of each
    # image contiguously and in input order; confirm against the library.
    return [
        deduplicate_captions(
            texts[index * CAPTIONS_PER_IMAGE:(index + 1) * CAPTIONS_PER_IMAGE]
        )
        for index in range(len(images))
    ]
@torch.inference_mode()
def generate_sampling_fallback_captions(
    model,
    processor,
    device: torch.device,
    torch_dtype: torch.dtype,
    image: Image.Image,
) -> List[str]:
    """Generate extra captions for one image using SAMPLING_FALLBACK_CONFIG.

    Args:
        model: Captioning model exposing ``generate``.
        processor: Processor used to encode the image and decode the output.
        device: Device the encoded inputs are moved to.
        torch_dtype: dtype the ``pixel_values`` input is cast to.
        image: Image to caption.

    Returns:
        The decoded captions after normalization and deduplication.
    """
    encoded = processor(images=[image], return_tensors="pt")
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    if "pixel_values" in model_inputs:
        model_inputs["pixel_values"] = model_inputs["pixel_values"].to(dtype=torch_dtype)

    output_ids = model.generate(**model_inputs, **SAMPLING_FALLBACK_CONFIG)
    texts = processor.batch_decode(output_ids, skip_special_tokens=True)
    return deduplicate_captions(texts)
def make_result_item(item: ImageItem, captions: List[str]) -> dict:
    """Build the JSON record stored for one captioned image.

    Args:
        item: The image the captions belong to.
        captions: Final captions for the image.

    Returns:
        A dict with the keys "image", "class", "captions" and "split".
    """
    record = {}
    record["image"] = item.image_field
    record["class"] = item.class_name
    record["captions"] = captions
    record["split"] = item.split
    return record
def _results_for_save(items: List[ImageItem], result_map: Dict[str, dict]) -> List[dict]:
    """Return every stored result: current items first, then carried-over ones."""
    current_keys = {item.image_field for item in items}
    ordered = [
        result_map[item.image_field]
        for item in items
        if item.image_field in result_map
    ]
    # Results loaded from an earlier run for images outside the current input
    # must be kept, otherwise saving would erase them from the output file.
    ordered.extend(
        entry
        for key, entry in result_map.items()
        if key not in current_keys
    )
    return ordered


def caption_images(
    model,
    processor,
    device: torch.device,
    torch_dtype: torch.dtype,
    items: List[ImageItem],
    existing_result_map: Dict[str, dict],
) -> Tuple[Dict[str, dict], List[dict]]:
    """Caption every item that has no stored result yet.

    Items are processed in batches of BATCH_SIZE. When beam search yields
    fewer than CAPTIONS_PER_IMAGE unique captions and the sampling fallback
    is enabled, sampled captions are added before the list is trimmed or
    padded to the target count. Results and errors are written to
    OUTPUT_JSON_PATH / ERROR_JSON_PATH whenever at least SAVE_EVERY_N_IMAGES
    images have been captioned since the previous save.

    Args:
        model: Captioning model.
        processor: Processor matching the model.
        device: Device used for inference.
        torch_dtype: dtype used for the image tensors.
        items: All images of the current run.
        existing_result_map: Results of earlier runs keyed by image field;
            images found in it are skipped. It is not modified.

    Returns:
        A tuple ``(result_map, error_list)``: the existing results merged
        with the new ones, and one error record per image that could not be
        loaded or captioned.
    """
    result_map = dict(existing_result_map)
    error_list = []
    target_items = [
        item
        for item in items
        if item.image_field not in result_map
    ]
    print(f"[INFO] 전체 이미지 수: {len(items)}")
    print(f"[INFO] 기존 결과 수: {len(existing_result_map)}")
    print(f"[INFO] 새로 캡셔닝할 이미지 수: {len(target_items)}")

    processed_count = 0
    last_saved_count = 0
    for batch_start in tqdm(range(0, len(target_items), BATCH_SIZE), desc="Captioning"):
        batch_items = target_items[batch_start:batch_start + BATCH_SIZE]
        images, valid_items, errors = load_batch_images(batch_items)
        error_list.extend(errors)
        if not images:
            continue

        try:
            batch_captions = generate_batch_captions(
                model=model,
                processor=processor,
                device=device,
                torch_dtype=torch_dtype,
                images=images,
            )
            for image, item, captions in zip(images, valid_items, batch_captions):
                if ENABLE_SAMPLING_FALLBACK and len(captions) < CAPTIONS_PER_IMAGE:
                    fallback_captions = generate_sampling_fallback_captions(
                        model=model,
                        processor=processor,
                        device=device,
                        torch_dtype=torch_dtype,
                        image=image,
                    )
                    captions = deduplicate_captions(captions + fallback_captions)
                captions = ensure_caption_count(
                    captions=captions,
                    target_count=CAPTIONS_PER_IMAGE,
                )
                result_map[item.image_field] = make_result_item(
                    item=item,
                    captions=captions,
                )
                processed_count += 1
        except Exception as e:
            # A failing batch (e.g. out of memory) is recorded and skipped so
            # the remaining batches still run.
            print("[ERROR] 배치 캡셔닝 실패")
            print(f"[ERROR] {type(e).__name__}: {e}")
            for item in valid_items:
                # Images stored before the failure happened are real results,
                # so they must not be reported as errors as well.
                if item.image_field in result_map:
                    continue
                error_list.append({
                    "image": item.image_field,
                    "class": item.class_name,
                    "split": item.split,
                    "error": str(e),
                })

        # processed_count grows by up to BATCH_SIZE per iteration, so testing
        # for an exact multiple of the interval would skip most checkpoints;
        # compare against the count at the last save instead.
        if SAVE_EVERY_N_IMAGES > 0 and processed_count - last_saved_count >= SAVE_EVERY_N_IMAGES:
            save_json(Path(OUTPUT_JSON_PATH), _results_for_save(items, result_map))
            save_json(Path(ERROR_JSON_PATH), error_list)
            last_saved_count = processed_count

    return result_map, error_list
# ============================================================
# 8. main
# ============================================================
def main():
    """Run the pipeline: collect images, assign splits, caption, and save.

    When RESUME_FROM_EXISTING_JSON is set, results already stored in
    OUTPUT_JSON_PATH are loaded first and their images are skipped. The
    final output contains the results for the images of this run followed
    by every carried-over result for images outside the current input path,
    so captioning one class at a time does not erase the other classes.
    """
    input_path = Path(INPUT_IMAGE_PATH).resolve()
    data_raw_root = Path(DATA_RAW_ROOT).resolve()
    output_json_path = Path(OUTPUT_JSON_PATH)
    error_json_path = Path(ERROR_JSON_PATH)

    items = collect_images(
        input_path=input_path,
        data_raw_root=data_raw_root,
    )
    items = assign_splits(items)

    existing_result_map = {}
    if RESUME_FROM_EXISTING_JSON:
        existing_result_map = load_existing_json(output_json_path)

    model, processor, device, torch_dtype = load_model()
    result_map, error_list = caption_images(
        model=model,
        processor=processor,
        device=device,
        torch_dtype=torch_dtype,
        items=items,
        existing_result_map=existing_result_map,
    )

    # Results for the images of this run, in collection order.
    final_results = [
        result_map[item.image_field]
        for item in items
        if item.image_field in result_map
    ]
    # Keep results from earlier runs whose images are not part of this run;
    # writing only the current items would delete them from the output file.
    current_keys = {item.image_field for item in items}
    final_results.extend(
        entry
        for key, entry in result_map.items()
        if key not in current_keys
    )

    save_json(output_json_path, final_results)
    save_json(error_json_path, error_list)
    print("[DONE] 캡셔닝 완료")
    print(f"[DONE] 결과 저장: {output_json_path}")
    print(f"[DONE] 에러 저장: {error_json_path}")
    print(f"[DONE] 정상 결과 수: {len(final_results)}")
    print(f"[DONE] 에러 수: {len(error_list)}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()