Mini-ImageNet / src /caption /generate_captions_florence2.py
ImAMJayKIM's picture
Upload 96 files
c1596ac verified
Raw
History Blame Contribute Delete
8.9 kB
import os
import json
import random
from pathlib import Path
import torch
from PIL import Image
from tqdm import tqdm
from dotenv import load_dotenv
from transformers import AutoProcessor, Florence2ForConditionalGeneration
# =========================================================
# 1. Settings
# =========================================================
# Caption every class: "data/raw"
# Caption a single class only: "data/raw/apple"
INPUT_IMAGE_DIR = "data/raw"
# Base path used to turn an image path into a value such as
# "pizza/hf_pizza_001.jpg"
DATA_RAW_ROOT = "data/raw"
# Path of the resulting JSON file
OUTPUT_JSON_PATH = "data/annotations/captions_flo_all.json"
# With transformers 5.7.0 the florence-community models are recommended
# base-ft: lightweight model fine-tuned on downstream tasks
# large-ft: heavier, but the quality may be better
MODEL_ID = "florence-community/Florence-2-base-ft"
# MODEL_ID = "florence-community/Florence-2-large-ft"
# Name of the Hugging Face token variable read from the .env file
# Public models may work without it, but providing a token is more reliable.
HF_TOKEN_ENV_NAME = "HF_TOKEN"
# Split ratio: 7 : 1.5 : 1.5 by default
TRAIN_RATIO = 0.7
VAL_RATIO = 0.15
TEST_RATIO = 0.15
# Seed that makes the split reproducible
RANDOM_SEED = 42
# Three captions are generated per image (one per task below)
# These are the caption tasks listed in the Florence-2 documentation.
CAPTION_TASKS = [
"<CAPTION>",
"<DETAILED_CAPTION>",
"<MORE_DETAILED_CAPTION>",
]
# Generation options
NUM_BEAMS = 3
MAX_NEW_TOKENS = 64
# Write an intermediate save after this many newly captioned images
SAVE_EVERY = 220
# Whether to skip images that are already present in the JSON
SKIP_ALREADY_DONE = True
# Allowed image file extensions (compared against the lower-cased suffix)
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".webp", ".bmp"]
# =========================================================
# 2. Collect the image list
# =========================================================
def get_image_list(input_dir=None, raw_root=None, extensions=None):
    """Collect every image file found below the input directory.

    Args:
        input_dir: Directory that is scanned recursively. Defaults to
            ``INPUT_IMAGE_DIR``.
        raw_root: Directory the reported relative paths are computed against.
            Defaults to ``DATA_RAW_ROOT``.
        extensions: Iterable of allowed suffixes (lower case, including the
            leading dot). Defaults to ``IMAGE_EXTENSIONS``.

    Returns:
        A list of dicts ordered by path, each with the keys
        ``"path"`` (the ``Path`` of the file), ``"image"`` (POSIX-style path
        relative to ``raw_root``) and ``"class"`` (first component of that
        relative path).

    Raises:
        FileNotFoundError: If the input directory does not exist.
        ValueError: If a resolved image path is not located below
            ``raw_root`` (raised by ``Path.relative_to``).
    """
    input_dir = Path(INPUT_IMAGE_DIR if input_dir is None else input_dir).resolve()
    raw_root = Path(DATA_RAW_ROOT if raw_root is None else raw_root).resolve()
    allowed_suffixes = frozenset(
        IMAGE_EXTENSIONS if extensions is None else extensions
    )

    if not input_dir.exists():
        raise FileNotFoundError(f"์ž…๋ ฅ ๊ฒฝ๋กœ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค: {input_dir}")

    image_list = []
    for image_path in sorted(input_dir.rglob("*")):
        if image_path.suffix.lower() not in allowed_suffixes:
            continue
        # rglob("*") also yields directories; a directory named like
        # "album.jpg" must not be reported as an image.
        if not image_path.is_file():
            continue

        # Example:
        #   /workspace/data/raw/pizza/hf_pizza_001.jpg -> pizza/hf_pizza_001.jpg
        relative_image_path = (
            image_path.resolve().relative_to(raw_root).as_posix()
        )
        # Example: pizza/hf_pizza_001.jpg -> pizza
        class_name = relative_image_path.split("/")[0]

        image_list.append({
            "path": image_path,
            "image": relative_image_path,
            "class": class_name,
        })
    return image_list
# =========================================================
# 3. Split into train / val / test
# =========================================================
def add_split(image_list, *, seed=None, train_ratio=None, val_ratio=None,
              test_ratio=None):
    """Assign a "train" / "val" / "test" split to every image, per class.

    Images are grouped by their ``"class"`` value, each group is shuffled and
    then cut into train / val / test according to the ratios. The ``"split"``
    key is written onto the given item dicts in place.

    A private ``random.Random`` instance is used instead of reseeding the
    module-level generator, so calling this function does not disturb other
    users of ``random``. For the same seed it produces the same shuffles as
    ``random.seed(seed)`` followed by ``random.shuffle``.

    Args:
        image_list: Iterable of dicts that each contain a ``"class"`` key.
        seed: Seed for the shuffle. Defaults to ``RANDOM_SEED``.
        train_ratio: Relative weight of the train split. Defaults to
            ``TRAIN_RATIO``.
        val_ratio: Relative weight of the val split. Defaults to
            ``VAL_RATIO``.
        test_ratio: Relative weight of the test split. Defaults to
            ``TEST_RATIO``.

    Returns:
        A list of the same item dicts, grouped by class in first-seen order
        and shuffled within each class, each carrying a ``"split"`` key.

    Raises:
        ValueError: If the three ratios do not sum to a positive number.
    """
    seed = RANDOM_SEED if seed is None else seed
    train_ratio = TRAIN_RATIO if train_ratio is None else train_ratio
    val_ratio = VAL_RATIO if val_ratio is None else val_ratio
    test_ratio = TEST_RATIO if test_ratio is None else test_ratio

    total_ratio = train_ratio + val_ratio + test_ratio
    if total_ratio <= 0:
        raise ValueError("split ratios must sum to a positive number")

    rng = random.Random(seed)

    # Group the images by class, keeping first-seen class order.
    class_map = {}
    for item in image_list:
        class_map.setdefault(item["class"], []).append(item)

    # Cut every class into train / val / test on its own so that each class
    # follows the requested ratio.
    result = []
    for items in class_map.values():
        rng.shuffle(items)
        total_count = len(items)
        train_count = round(total_count * train_ratio / total_ratio)
        val_count = round(total_count * val_ratio / total_ratio)
        val_end = train_count + val_count

        for index, item in enumerate(items):
            if index < train_count:
                split = "train"
            elif index < val_end:
                split = "val"
            else:
                # Whatever is left after rounding goes to test.
                split = "test"
            item["split"] = split
            result.append(item)
    return result
# =========================================================
# 4. Prepare the Florence-2 model
# =========================================================
def load_model():
    """Load the Florence-2 processor and model onto the best available device.

    The Hugging Face token is read from the environment variable named by
    ``HF_TOKEN_ENV_NAME`` after ``load_dotenv()`` has run. On CUDA the model
    uses bfloat16 when ``torch.cuda.is_bf16_supported()`` reports support and
    float16 otherwise; without CUDA it runs on the CPU in float32.

    Returns:
        A ``(model, processor, device, torch_dtype)`` tuple, with the model
        already moved to ``device`` and switched to eval mode.
    """
    load_dotenv()
    hf_token = os.getenv(HF_TOKEN_ENV_NAME)

    if not torch.cuda.is_available():
        device, torch_dtype = "cpu", torch.float32
    else:
        device = "cuda"
        # Prefer bfloat16 when the GPU supports it, else fall back to float16.
        torch_dtype = (
            torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        )

    print(f"device: {device}")
    print(f"dtype: {torch_dtype}")
    print(f"model: {MODEL_ID}")

    processor = AutoProcessor.from_pretrained(MODEL_ID, token=hf_token)

    model = Florence2ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        dtype=torch_dtype,
        token=hf_token,
    )
    model = model.to(device)
    model.eval()

    return model, processor, device, torch_dtype
# =========================================================
# 5. Caption a single image
# =========================================================
def make_caption(image, task, model, processor, device, torch_dtype):
    """Generate one caption for ``image`` with the given task prompt.

    Args:
        image: Image handed to the processor; its ``size`` attribute is also
            passed to the post-processing step.
        task: Task prompt string (for example ``"<CAPTION>"``).
        model: Model whose ``generate`` method produces the token ids.
        processor: Processor used to build the inputs, decode the output and
            post-process the generated text.
        device: Device the inputs are moved to.
        torch_dtype: Dtype passed along when moving the inputs.

    Returns:
        The caption stored under ``task`` in the post-processed result, with
        surrounding whitespace stripped. An empty string if the key is
        missing; non-string values are converted with ``str``.
    """
    model_inputs = processor(
        text=task,
        images=image,
        return_tensors="pt",
    ).to(device, torch_dtype)

    # Beam search without sampling; no gradients are needed for generation.
    with torch.no_grad():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=NUM_BEAMS,
            do_sample=False,
        )

    # Special tokens are kept in the decoded text that is handed to
    # post_process_generation.
    decoded_batch = processor.batch_decode(
        output_ids,
        skip_special_tokens=False,
    )
    parsed = processor.post_process_generation(
        decoded_batch[0],
        task=task,
        image_size=image.size,
    )

    raw_caption = parsed.get(task, "")
    text = raw_caption if isinstance(raw_caption, str) else str(raw_caption)
    return text.strip()
def make_three_captions(image_path, model, processor, device, torch_dtype):
    """Generate one caption per entry of ``CAPTION_TASKS`` for one image file.

    Args:
        image_path: Path of the image file to open.
        model: Model forwarded to ``make_caption``.
        processor: Processor forwarded to ``make_caption``.
        device: Device forwarded to ``make_caption``.
        torch_dtype: Dtype forwarded to ``make_caption``.

    Returns:
        A list of caption strings in the same order as ``CAPTION_TASKS``.
    """
    # Open inside a context manager so the underlying file is closed
    # deterministically; convert() returns a separate in-memory image that
    # stays usable after the file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    return [
        make_caption(
            image=image,
            task=task,
            model=model,
            processor=processor,
            device=device,
            torch_dtype=torch_dtype,
        )
        for task in CAPTION_TASKS
    ]
# =========================================================
# 6. Load / save the existing JSON
# =========================================================
def load_existing_result():
    """Read previously saved captions from ``OUTPUT_JSON_PATH``.

    Returns:
        A dict mapping each entry's ``"image"`` value to the entry itself, or
        an empty dict when the output file does not exist yet. If the same
        ``"image"`` value occurs more than once, the last entry wins.
    """
    output_path = Path(OUTPUT_JSON_PATH)
    if output_path.exists():
        with output_path.open("r", encoding="utf-8") as f:
            entries = json.load(f)
        return {entry["image"]: entry for entry in entries}
    return {}
def save_result(result_map, output_path=None):
    """Write all results to the output JSON file, sorted by image key.

    The data is first written to a temporary sibling file and then moved over
    the target with an atomic replace. An interrupted or failed write
    therefore never leaves a truncated JSON file behind, which would
    otherwise break resuming from the existing results.

    Args:
        result_map: Dict whose values are result entries; every entry must
            contain an ``"image"`` key, which is used for sorting.
        output_path: Target file. Defaults to ``OUTPUT_JSON_PATH``.
    """
    output_path = Path(OUTPUT_JSON_PATH if output_path is None else output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    result_list = sorted(result_map.values(), key=lambda entry: entry["image"])

    # Same directory as the target so the final rename stays on one
    # filesystem.
    temp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with temp_path.open("w", encoding="utf-8") as f:
            json.dump(result_list, f, ensure_ascii=False, indent=4)
        temp_path.replace(output_path)
    finally:
        # After a successful replace the temp file is already gone; after a
        # failure this removes the partial file.
        temp_path.unlink(missing_ok=True)
# =========================================================
# 7. Run
# =========================================================
def main():
    """Caption every listed image and store the results as JSON.

    Steps: collect the images, assign splits, load results from an earlier
    run, load the model, then caption each image that is not already done
    (when ``SKIP_ALREADY_DONE`` is set). Results are saved every
    ``SAVE_EVERY`` newly captioned images and once more when the loop ends.
    The final save sits in a ``finally`` clause so that progress made since
    the last periodic save is also kept when the run is interrupted (for
    example with Ctrl-C).

    A failure on a single image is reported and counted, and processing
    continues with the next image.
    """
    # NOTE(review): the user-facing messages below appear mis-encoded in this
    # copy of the file; they are left untouched here — confirm against the
    # original source encoding.
    print("์ด๋ฏธ์ง€ ๋ชฉ๋ก์„ ์ฝ๋Š” ์ค‘์ž…๋‹ˆ๋‹ค.")
    image_list = get_image_list()
    image_list = add_split(image_list)
    print(f"์ด ์ด๋ฏธ์ง€ ์ˆ˜: {len(image_list)}")

    result_map = load_existing_result()
    model, processor, device, torch_dtype = load_model()

    new_count = 0
    skip_count = 0
    fail_count = 0

    try:
        for item in tqdm(image_list):
            image_key = item["image"]
            if SKIP_ALREADY_DONE and image_key in result_map:
                skip_count += 1
                continue

            try:
                captions = make_three_captions(
                    image_path=item["path"],
                    model=model,
                    processor=processor,
                    device=device,
                    torch_dtype=torch_dtype,
                )
                result_map[image_key] = {
                    "image": item["image"],
                    "class": item["class"],
                    "captions": captions,
                    "split": item["split"],
                }
                new_count += 1
                if new_count % SAVE_EVERY == 0:
                    save_result(result_map)
            except Exception as e:
                # Deliberate best effort: one bad image must not stop the
                # whole run.
                fail_count += 1
                print(f"\n์‹คํŒจํ•œ ์ด๋ฏธ์ง€: {item['path']}")
                print(f"์—๋Ÿฌ ๋‚ด์šฉ: {e}")
    finally:
        # Runs on normal completion and on interruption alike.
        save_result(result_map)

    print("\n์บก์…”๋‹ ์™„๋ฃŒ")
    print(f"์ƒˆ๋กœ ์ฒ˜๋ฆฌํ•œ ์ด๋ฏธ์ง€ ์ˆ˜: {new_count}")
    print(f"๊ฑด๋„ˆ๋›ด ์ด๋ฏธ์ง€ ์ˆ˜: {skip_count}")
    print(f"์‹คํŒจํ•œ ์ด๋ฏธ์ง€ ์ˆ˜: {fail_count}")
    print(f"์ €์žฅ ์œ„์น˜: {OUTPUT_JSON_PATH}")


if __name__ == "__main__":
    main()