import os
import json
import random
from pathlib import Path
import torch
from PIL import Image
from tqdm import tqdm
from dotenv import load_dotenv
from transformers import AutoProcessor, Florence2ForConditionalGeneration
# =========================================================
# 1. Settings
# =========================================================
# Caption every class:        "data/raw"
# Caption one class only:     "data/raw/apple"
INPUT_IMAGE_DIR = "data/raw"
# Base path used to express each "image" value as "pizza/hf_pizza_001.jpg"
DATA_RAW_ROOT = "data/raw"
# Path the result JSON is written to
OUTPUT_JSON_PATH = "data/annotations/captions_flo_all.json"
# Original author's note: with transformers 5.7.0 the florence-community
# checkpoints are the recommended ones.
# base-ft:  lightweight model, fine-tuned on downstream tasks
# large-ft: heavier, but may give better quality
MODEL_ID = "florence-community/Florence-2-base-ft"
# MODEL_ID = "florence-community/Florence-2-large-ft"
# Name of the Hugging Face token variable read from the .env file.
# A public model may work without it, but providing a token is more reliable.
HF_TOKEN_ENV_NAME = "HF_TOKEN"
# Split ratios: default 7 : 1.5 : 1.5
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15
# Seed that makes the split reproducible
RANDOM_SEED = 42
# Three captions are generated per image, one per Florence-2 caption task
# prompt, ordered from the shortest to the most detailed description.
# NOTE: the prompts are literal angle-bracket tokens; tools that strip
# HTML-like tags will silently empty these strings.
CAPTION_TASKS = [
    "<CAPTION>",
    "<DETAILED_CAPTION>",
    "<MORE_DETAILED_CAPTION>",
]
# Generation options
NUM_BEAMS = 3
MAX_NEW_TOKENS = 64
# Save an intermediate checkpoint every this many newly captioned images
SAVE_EVERY = 220
# Whether to skip images that are already present in the output JSON
SKIP_ALREADY_DONE = True
# Accepted image file extensions (compared against the lower-cased suffix)
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp", ".bmp"]
# =========================================================
# 2. 이미지 목록 가져오기
# =========================================================
def get_image_list():
    """Collect every image file under INPUT_IMAGE_DIR together with its class.

    The directory is walked recursively and entries are visited in sorted
    path order, so the returned list is deterministic.

    Returns:
        A list of dicts, one per image, with the keys:
            "path":  Path of the image file (below the resolved input dir).
            "image": POSIX-style path relative to DATA_RAW_ROOT,
                     e.g. "pizza/hf_pizza_001.jpg".
            "class": first component of "image", e.g. "pizza".

    Raises:
        FileNotFoundError: INPUT_IMAGE_DIR does not exist.
        ValueError: an image resolves to a location outside DATA_RAW_ROOT
            (raised by ``Path.relative_to``).
    """
    input_dir = Path(INPUT_IMAGE_DIR).resolve()
    raw_root = Path(DATA_RAW_ROOT).resolve()
    if not input_dir.exists():
        raise FileNotFoundError(f"입력 경로가 없습니다: {input_dir}")

    image_list = []
    for image_path in sorted(input_dir.rglob("*")):
        if image_path.suffix.lower() not in IMAGE_EXTENSIONS:
            continue
        # rglob also yields directories; one named like "foo.jpg" is not an image.
        if not image_path.is_file():
            continue
        # e.g. /workspace/data/raw/pizza/hf_pizza_001.jpg -> pizza/hf_pizza_001.jpg
        relative_image_path = image_path.resolve().relative_to(raw_root).as_posix()
        # e.g. pizza/hf_pizza_001.jpg -> pizza
        class_name = relative_image_path.split("/")[0]
        image_list.append({
            "path": image_path,
            "image": relative_image_path,
            "class": class_name,
        })
    return image_list
# =========================================================
# 3. train / val / test 나누기
# =========================================================
def add_split(image_list):
    """Assign a "train" / "val" / "test" split to every image, class by class.

    Images are grouped by their "class" value, each group is shuffled with a
    generator seeded from RANDOM_SEED, and the shuffled group is cut into
    train / val / test according to TRAIN_RATIO, VAL_RATIO and TEST_RATIO
    (normalised by their sum). Train and val sizes are rounded with
    ``round()``; whatever remains becomes test.

    Args:
        image_list: list of dicts that each contain a "class" key.

    Returns:
        A list holding the same dict objects, each with a "split" key added
        in place, grouped by class in first-seen class order.
    """
    total_ratio = TRAIN_RATIO + VAL_RATIO + TEST_RATIO
    # A private generator yields the same sequence as random.seed(RANDOM_SEED)
    # but leaves the process-wide RNG state untouched for other code.
    rng = random.Random(RANDOM_SEED)

    # Group images by class, preserving first-seen class order.
    class_map = {}
    for item in image_list:
        class_map.setdefault(item["class"], []).append(item)

    result = []
    for items in class_map.values():
        rng.shuffle(items)
        total_count = len(items)
        train_count = round(total_count * TRAIN_RATIO / total_ratio)
        val_count = round(total_count * VAL_RATIO / total_ratio)
        val_end = train_count + val_count
        for index, item in enumerate(items):
            if index < train_count:
                split = "train"
            elif index < val_end:
                split = "val"
            else:
                split = "test"
            item["split"] = split
            result.append(item)
    return result
# =========================================================
# 4. Florence-2 모델 준비
# =========================================================
def load_model():
    """Load the Florence-2 model and its processor onto the best device.

    Reads the Hugging Face token named by HF_TOKEN_ENV_NAME from the
    environment (after loading a .env file). Uses CUDA when available,
    with bfloat16 if the GPU supports it and float16 otherwise; falls back
    to CPU with float32.

    Returns:
        A tuple ``(model, processor, device, torch_dtype)`` where the model
        is already moved to ``device`` and switched to eval mode.
    """
    load_dotenv()
    hf_token = os.getenv(HF_TOKEN_ENV_NAME)

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    if not use_cuda:
        torch_dtype = torch.float32
    elif torch.cuda.is_bf16_supported():
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float16

    print(f"device: {device}")
    print(f"dtype: {torch_dtype}")
    print(f"model: {MODEL_ID}")

    processor = AutoProcessor.from_pretrained(MODEL_ID, token=hf_token)
    model = Florence2ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        dtype=torch_dtype,
        token=hf_token,
    )
    model = model.to(device)
    model.eval()
    return model, processor, device, torch_dtype
# =========================================================
# 5. 이미지 1장 캡셔닝
# =========================================================
def make_caption(image, task, model, processor, device, torch_dtype):
    """Generate one caption for ``image`` using the given task prompt.

    Args:
        image: image object exposing ``.size`` (the caller passes an RGB
            PIL image).
        task: task prompt string; also the key looked up in the parsed result.
        model: generation model returned by ``load_model``.
        processor: matching processor returned by ``load_model``.
        device: device the inputs are moved to.
        torch_dtype: dtype passed along when moving the inputs.

    Returns:
        The caption with surrounding whitespace stripped; an empty string
        when the parsed result has no entry for ``task``.
    """
    model_inputs = processor(text=task, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device, torch_dtype)

    # Deterministic beam search; no gradients are needed for inference.
    with torch.no_grad():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=NUM_BEAMS,
            do_sample=False,
        )

    # Decode without stripping special tokens, then let the processor's
    # task-specific post-processing turn the raw text into a result dict.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)
    parsed = processor.post_process_generation(
        decoded[0],
        task=task,
        image_size=image.size,
    )

    caption = parsed.get(task, "")
    text = caption if isinstance(caption, str) else str(caption)
    return text.strip()
def make_three_captions(image_path, model, processor, device, torch_dtype):
    """Caption one image file once per entry in CAPTION_TASKS.

    Args:
        image_path: path of the image file to open.
        model, processor, device, torch_dtype: values returned by
            ``load_model``; forwarded unchanged to ``make_caption``.

    Returns:
        A list of caption strings, in the same order as CAPTION_TASKS.
    """
    # Open inside a context manager so the file handle is released promptly;
    # convert() returns an independent in-memory copy that outlives the file.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    return [
        make_caption(
            image=image,
            task=task,
            model=model,
            processor=processor,
            device=device,
            torch_dtype=torch_dtype,
        )
        for task in CAPTION_TASKS
    ]
# =========================================================
# 6. 기존 JSON 읽기 / 저장하기
# =========================================================
def load_existing_result():
    """Read previously saved captions from OUTPUT_JSON_PATH.

    Returns:
        A dict mapping each entry's "image" value to the entry itself, or an
        empty dict when the output file does not exist yet. If the same
        "image" appears more than once, the last entry wins.
    """
    output_path = Path(OUTPUT_JSON_PATH)
    if not output_path.exists():
        return {}

    entries = json.loads(output_path.read_text(encoding="utf-8"))
    return {entry["image"]: entry for entry in entries}
def save_result(result_map):
    """Write all results to OUTPUT_JSON_PATH as a JSON list sorted by "image".

    The data is first written to a sibling ``<name>.tmp`` file and then moved
    over the target, so an interruption or a serialization error never leaves
    a truncated checkpoint behind (which would break a later resume).

    Args:
        result_map: dict mapping image key -> result dict; every value must
            contain an "image" key.

    Raises:
        TypeError / ValueError: a value cannot be serialized to JSON; the
            previously saved file, if any, is left untouched.
    """
    output_path = Path(OUTPUT_JSON_PATH)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    result_list = sorted(result_map.values(), key=lambda x: x["image"])

    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(result_list, f, ensure_ascii=False, indent=4)
        # Replace the target only after the new content is completely written.
        tmp_path.replace(output_path)
    finally:
        # No-op after a successful replace; removes a partial file on failure.
        tmp_path.unlink(missing_ok=True)
# =========================================================
# 7. 실행
# =========================================================
def main():
    """Caption every image, checkpointing results to OUTPUT_JSON_PATH.

    Steps: list the images, assign splits, load earlier results, load the
    model, then caption each image not yet present in the results (when
    SKIP_ALREADY_DONE is set). A failure on one image is reported and counted
    without stopping the run. Results are saved every SAVE_EVERY new images
    and once more when the loop ends, including when it is interrupted
    (e.g. Ctrl+C), so finished work is not lost.
    """
    print("이미지 목록을 읽는 중입니다.")
    image_list = add_split(get_image_list())
    print(f"총 이미지 수: {len(image_list)}")

    result_map = load_existing_result()
    model, processor, device, torch_dtype = load_model()

    new_count = 0
    skip_count = 0
    fail_count = 0

    try:
        for item in tqdm(image_list):
            image_key = item["image"]
            if SKIP_ALREADY_DONE and image_key in result_map:
                skip_count += 1
                continue

            # Only the captioning call is guarded: a bad image must not stop
            # the run, but a checkpoint-save error must not be misreported
            # as an image failure.
            try:
                captions = make_three_captions(
                    image_path=item["path"],
                    model=model,
                    processor=processor,
                    device=device,
                    torch_dtype=torch_dtype,
                )
            except Exception as e:
                fail_count += 1
                print(f"\n실패한 이미지: {item['path']}")
                print(f"에러 내용: {e}")
                continue

            result_map[image_key] = {
                "image": item["image"],
                "class": item["class"],
                "captions": captions,
                "split": item["split"],
            }
            new_count += 1

            # SAVE_EVERY <= 0 disables intermediate checkpoints instead of
            # raising ZeroDivisionError.
            if SAVE_EVERY > 0 and new_count % SAVE_EVERY == 0:
                save_result(result_map)
    finally:
        # Runs on normal completion and on interruption alike.
        save_result(result_map)

    print("\n캡셔닝 완료")
    print(f"새로 처리한 이미지 수: {new_count}")
    print(f"건너뛴 이미지 수: {skip_count}")
    print(f"실패한 이미지 수: {fail_count}")
    print(f"저장 위치: {OUTPUT_JSON_PATH}")


if __name__ == "__main__":
    main()