Spaces:
Sleeping
Sleeping
import json
import random
from collections import defaultdict
from itertools import cycle, islice
from pathlib import Path

import torch
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
# =========================================================
# 1. Configuration values
# =========================================================
# To caption every class:
# INPUT_IMAGE_DIR = "/workspace/data/raw"
#
# To caption one specific class only:
# INPUT_IMAGE_DIR = "/workspace/data/raw/apple"
# --- Input / output locations and model ---------------------------------
INPUT_IMAGE_DIR = "/workspace/data/raw/airplane"
OUTPUT_JSON_PATH = "/workspace/data/annotations/annotation.json"
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

# --- Dataset shape -------------------------------------------------------
CAPTIONS_PER_IMAGE = 3
SPLIT_RATIO = dict(train=0.7, val=0.15, test=0.15)
RANDOM_SEED = 42
BATCH_SIZE = 8
IMAGE_EXTENSIONS = [
    ".jpg",
    ".jpeg",
    ".png",
    ".webp",
    ".bmp",
]

# How INPUT_IMAGE_DIR is interpreted:
#   "auto":  decide automatically -- data/raw means every class,
#            data/raw/apple means only the apple class
#   "raw":   treat INPUT_IMAGE_DIR as the raw folder holding all classes
#   "class": treat INPUT_IMAGE_DIR itself as a single class folder
INPUT_MODE = "auto"

# Whether to strip the period at the end of a caption sentence.
REMOVE_TRAILING_PERIOD = True

# Beam search settings.
GENERATION_CONFIG = dict(
    max_new_tokens=32,
    num_beams=8,
    num_return_sequences=CAPTIONS_PER_IMAGE,
    early_stopping=True,
    no_repeat_ngram_size=2,
    repetition_penalty=1.1,
    length_penalty=0.8,
)

# When beam search yields duplicated captions, top up with sampling.
ENABLE_SAMPLING_FALLBACK = True
SAMPLING_FALLBACK_CONFIG = dict(
    max_new_tokens=32,
    do_sample=True,
    top_p=0.9,
    temperature=0.8,
    num_return_sequences=CAPTIONS_PER_IMAGE * 2,
    no_repeat_ngram_size=2,
    repetition_penalty=1.1,
)
MAX_FALLBACK_ROUNDS = 3

# If the target count still cannot be reached, whether to allow duplicate
# captions in order to reach it anyway.
FILL_WITH_DUPLICATES_IF_NEEDED = True
# =========================================================
# 2. Basic utility functions
# =========================================================
def validate_config():
    """Sanity-check the module-level settings before any work is done.

    Raises:
        ValueError: if the SPLIT_RATIO values do not add up to 1 (tolerance
            1e-6), if ``num_beams`` is smaller than CAPTIONS_PER_IMAGE, or if
            ``num_return_sequences`` differs from CAPTIONS_PER_IMAGE.
    """
    total_ratio = sum(SPLIT_RATIO.values())
    ratio_error = abs(total_ratio - 1.0)
    if ratio_error > 1e-6:
        raise ValueError(f"SPLIT_RATIO์ ํฉ์ 1์ด์ด์ผ ํฉ๋๋ค. ํ์ฌ ํฉ: {total_ratio}")

    # Beam search cannot return more sequences than it keeps beams.
    beam_width = GENERATION_CONFIG["num_beams"]
    if beam_width < CAPTIONS_PER_IMAGE:
        raise ValueError("num_beams๋ CAPTIONS_PER_IMAGE๋ณด๋ค ํฌ๊ฑฐ๋ ๊ฐ์์ผ ํฉ๋๋ค.")

    returned_per_image = GENERATION_CONFIG["num_return_sequences"]
    if returned_per_image != CAPTIONS_PER_IMAGE:
        raise ValueError("GENERATION_CONFIG์ num_return_sequences๋ CAPTIONS_PER_IMAGE์ ๊ฐ์์ผ ํฉ๋๋ค.")
def is_image_file(path: Path) -> bool:
    """Return True when the path's lower-cased suffix is listed in IMAGE_EXTENSIONS."""
    extension = path.suffix.lower()
    return extension in IMAGE_EXTENSIONS
def clean_caption(text: str) -> str:
    """Normalize a generated caption.

    Runs of whitespace are collapsed to single spaces and the ends are
    trimmed. When REMOVE_TRAILING_PERIOD is enabled, trailing periods are
    removed together with any space that precedes them.

    Args:
        text: Raw caption text.

    Returns:
        The normalized caption (possibly an empty string).
    """
    caption = " ".join(text.split())
    if REMOVE_TRAILING_PERIOD:
        # Strip periods AND spaces: a tokenized ending such as "a cat ." would
        # otherwise become "a cat " and no longer match "a cat" when captions
        # are de-duplicated.
        caption = caption.rstrip(". ")
    return caption
def unique_captions(captions):
    """Clean every caption and drop empties and case-insensitive duplicates.

    The first occurrence of each caption wins and input order is preserved.
    """
    seen_keys = set()
    kept = []
    for raw in captions:
        cleaned = clean_caption(raw)
        if not cleaned:
            continue
        lowered = cleaned.lower()
        if lowered in seen_keys:
            continue
        seen_keys.add(lowered)
        kept.append(cleaned)
    return kept
def load_image(image_path: Path):
    """Open an image file and return it converted to RGB.

    Args:
        image_path: Location of the image on disk.

    Returns:
        The RGB image, or None when the file cannot be opened or decoded
        (a ``[SKIP]`` line is printed in that case).
    """
    try:
        # The context manager closes the source image (and its file) as soon
        # as the RGB copy exists, instead of leaving that to garbage
        # collection.
        with Image.open(image_path) as source_image:
            return source_image.convert("RGB")
    except (UnidentifiedImageError, OSError) as e:
        print(f"[SKIP] ์ด๋ฏธ์ง๋ฅผ ์ด ์ ์์ต๋๋ค: {image_path} / error: {e}")
        return None
# =========================================================
# 3. Collect the list of images
# =========================================================
def has_direct_images(input_dir: Path) -> bool:
    """Return True if at least one image file sits directly inside ``input_dir``.

    Only immediate children are inspected; sub-directories are not descended.
    """
    return any(
        entry.is_file() and is_image_file(entry)
        for entry in input_dir.iterdir()
    )
def get_relative_base_dir(input_dir: Path) -> Path:
    """Choose the directory that JSON ``image`` paths are made relative to.

    The goal is for every ``image`` value to look like
    ``class_folder/image_name``.

    Example 1)
        INPUT_IMAGE_DIR = /workspace/data/raw
        image file      = /workspace/data/raw/pizza/hf_pizza_001.jpg
        relative base   = /workspace/data/raw
        result          = pizza/hf_pizza_001.jpg

    Example 2)
        INPUT_IMAGE_DIR = /workspace/data/raw/apple
        image file      = /workspace/data/raw/apple/hf_apple_001.jpg
        relative base   = /workspace/data/raw
        result          = apple/hf_apple_001.jpg

    Raises:
        ValueError: if INPUT_MODE is not 'auto', 'raw' or 'class'.
    """
    if INPUT_MODE not in ("raw", "class", "auto"):
        raise ValueError("INPUT_MODE์ 'auto', 'raw', 'class' ์ค ํ๋์ฌ์ผ ํฉ๋๋ค.")

    if INPUT_MODE == "raw":
        return input_dir
    if INPUT_MODE == "class":
        return input_dir.parent

    # "auto": a folder that directly contains images is taken to be a single
    # class folder, so paths are made relative to its parent.
    looks_like_class_dir = has_direct_images(input_dir)
    return input_dir.parent if looks_like_class_dir else input_dir
def collect_image_records(input_dir: str):
    """Find every image under ``input_dir`` (recursively) and describe it.

    Returns:
        A list, sorted by file path, of dicts with keys ``path`` (the file's
        Path), ``image`` (POSIX-style path relative to the base directory)
        and ``class`` (first component of that relative path).

    Raises:
        FileNotFoundError: if ``input_dir`` does not exist.
        ValueError: if no image file is found.
    """
    input_path = Path(input_dir)
    if not input_path.exists():
        raise FileNotFoundError(f"์ด๋ฏธ์ง ๊ฒฝ๋ก๊ฐ ์กด์ฌํ์ง ์์ต๋๋ค: {input_path}")

    base_dir = get_relative_base_dir(input_path)
    image_files = (
        candidate
        for candidate in sorted(input_path.rglob("*"))
        if candidate.is_file() and is_image_file(candidate)
    )

    records = []
    for image_file in image_files:
        relative = image_file.relative_to(base_dir)
        records.append({
            "path": image_file,
            "image": relative.as_posix(),
            # For an image value of apple/xxx.jpg the class is "apple".
            "class": relative.parts[0],
        })

    if not records:
        raise ValueError(f"์บก์ ๋ํ ์ด๋ฏธ์ง๊ฐ ์์ต๋๋ค: {input_path}")
    return records
# =========================================================
# 4. Assign train / val / test splits
# =========================================================
def assign_split(records):
    """Assign a ``split`` of "train", "val" or "test" to every record, per class.

    Records are grouped by their ``class`` value and each group is shuffled
    deterministically. The first ``int(n * train_ratio)`` items of a group
    become "train", the next ``int(n * val_ratio)`` become "val" and the
    remainder "test".

    Args:
        records: Dicts with at least ``class`` and ``image`` keys. Each dict
            is modified in place by adding a ``split`` key.

    Returns:
        A new list holding the same dicts, sorted by their ``image`` value.
    """
    # A private generator seeded with RANDOM_SEED yields the same shuffles as
    # random.seed(RANDOM_SEED) + random.shuffle, without resetting the
    # process-wide RNG for every other user of the random module.
    rng = random.Random(RANDOM_SEED)

    class_map = defaultdict(list)
    for record in records:
        class_map[record["class"]].append(record)

    result = []
    for items in class_map.values():
        rng.shuffle(items)
        total = len(items)
        train_end = int(total * SPLIT_RATIO["train"])
        val_end = train_end + int(total * SPLIT_RATIO["val"])
        for idx, item in enumerate(items):
            if idx < train_end:
                item["split"] = "train"
            elif idx < val_end:
                item["split"] = "val"
            else:
                item["split"] = "test"
            result.append(item)

    result.sort(key=lambda item: item["image"])
    return result
# =========================================================
# 5. Load the model
# =========================================================
def load_model():
    """Load the captioning model, image processor and tokenizer for MODEL_NAME.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode. If the tokenizer has no pad token, its EOS token is reused and
    the resulting pad token id is recorded on the model config.

    Returns:
        A ``(model, processor, tokenizer, device)`` tuple.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print(f"[INFO] device: {device}")
    print(f"[INFO] model: {MODEL_NAME}")

    model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
    processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    needs_pad_token = tokenizer.pad_token is None
    if needs_pad_token:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.pad_token_id

    model.to(device)
    model.eval()
    return model, processor, tokenizer, device
# =========================================================
# 6. Generate captions
# =========================================================
def decode_output_ids(output_ids, tokenizer):
    """Decode generated token ids (special tokens skipped) and clean each caption."""
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return list(map(clean_caption, decoded))
def generate_by_beam_search(images, model, processor, tokenizer, device):
    """Caption a batch of images using the GENERATION_CONFIG settings.

    Returns:
        One list per input image, each a slice of CAPTIONS_PER_IMAGE
        consecutive decoded captions.
    """
    encoded = processor(images=images, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)
    output_ids = model.generate(pixel_values, **GENERATION_CONFIG)
    captions = decode_output_ids(output_ids, tokenizer)

    # NOTE(review): this assumes generate() returns the sequences grouped per
    # input image, in input order -- confirm against the generate() docs.
    return [
        captions[index * CAPTIONS_PER_IMAGE:(index + 1) * CAPTIONS_PER_IMAGE]
        for index, _ in enumerate(images)
    ]
def generate_by_sampling(image, model, processor, tokenizer, device):
    """Caption a single image using the SAMPLING_FALLBACK_CONFIG settings.

    Returns:
        The decoded and cleaned captions as a flat list.
    """
    encoded = processor(images=[image], return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)
    sampled_ids = model.generate(pixel_values, **SAMPLING_FALLBACK_CONFIG)
    return decode_output_ids(sampled_ids, tokenizer)
def complete_caption_count(captions, original_candidates):
    """Bring the caption list for one image to CAPTIONS_PER_IMAGE entries.

    Primary goal: return CAPTIONS_PER_IMAGE captions with as few duplicates
    as possible. If the model keeps producing near-identical sentences there
    may not be enough unique captions; in that case, when
    FILL_WITH_DUPLICATES_IF_NEEDED is True, duplicates are allowed in order
    to reach the target count.

    Args:
        captions: Captions gathered so far; they are cleaned and
            de-duplicated first.
        original_candidates: Every caption generated for the image,
            duplicates included; used as the pool of fillers.

    Returns:
        Up to CAPTIONS_PER_IMAGE captions. With duplicate filling enabled the
        list is exactly CAPTIONS_PER_IMAGE long whenever at least one
        non-empty caption is available; it is shorter only when filling is
        disabled or nothing usable was generated.
    """
    captions = unique_captions(captions)
    if len(captions) >= CAPTIONS_PER_IMAGE:
        return captions[:CAPTIONS_PER_IMAGE]
    if not FILL_WITH_DUPLICATES_IF_NEEDED:
        return captions

    # Fillers come from the raw candidates; if none is usable, the unique
    # captions themselves are repeated.
    cleaned_candidates = [clean_caption(candidate) for candidate in original_candidates]
    pool = [candidate for candidate in cleaned_candidates if candidate] or list(captions)
    if pool:
        # Cycle through the pool so the target is reached even when the pool
        # is shorter than the number of missing captions (a single pass could
        # leave the list short).
        missing = CAPTIONS_PER_IMAGE - len(captions)
        captions.extend(islice(cycle(pool), missing))
    return captions
def generate_captions_for_batch(batch_records, model, processor, tokenizer, device):
    """Produce the caption entries for one batch of image records.

    Images that cannot be loaded are left out. Beam search runs once for the
    whole batch; any image that ends up with fewer than CAPTIONS_PER_IMAGE
    unique captions is topped up by sampling (when enabled) for at most
    MAX_FALLBACK_ROUNDS rounds, and the list is then finalized by
    ``complete_caption_count``.

    Returns:
        A list of dicts with keys ``image``, ``class``, ``captions`` and
        ``split``, one per successfully loaded image (empty if none loaded).
    """
    loaded = []
    for record in batch_records:
        image = load_image(record["path"])
        if image is not None:
            loaded.append((record, image))

    if not loaded:
        return []

    beam_caption_groups = generate_by_beam_search(
        images=[image for _, image in loaded],
        model=model,
        processor=processor,
        tokenizer=tokenizer,
        device=device,
    )

    results = []
    for (record, image), beam_captions in zip(loaded, beam_caption_groups):
        all_candidates = list(beam_captions)
        captions = unique_captions(beam_captions)

        rounds_left = MAX_FALLBACK_ROUNDS if ENABLE_SAMPLING_FALLBACK else 0
        while rounds_left > 0 and len(captions) < CAPTIONS_PER_IMAGE:
            sampled_captions = generate_by_sampling(
                image=image,
                model=model,
                processor=processor,
                tokenizer=tokenizer,
                device=device,
            )
            all_candidates.extend(sampled_captions)
            captions = unique_captions(captions + sampled_captions)
            rounds_left -= 1

        final_captions = complete_caption_count(
            captions=captions,
            original_candidates=all_candidates,
        )
        results.append({
            "image": record["image"],
            "class": record["class"],
            "captions": final_captions,
            "split": record["split"],
        })
    return results
# =========================================================
# 7. Save JSON
# =========================================================
def save_json(data, output_path: str):
    """Write ``data`` to ``output_path`` as UTF-8 JSON with a 4-space indent.

    Missing parent directories are created. Non-ASCII characters are written
    as-is rather than escaped. Two ``[DONE]`` lines (the path and
    ``len(data)``) are printed afterwards.
    """
    output_path = Path(output_path)
    target_dir = output_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(data, handle, ensure_ascii=False, indent=4)

    print(f"[DONE] JSON ์ ์ฅ ์๋ฃ: {output_path}")
    print(f"[DONE] ์ด ์ด๋ฏธ์ง ์: {len(data)}")
# =========================================================
# 8. Run
# =========================================================
def main():
    """Run the pipeline: validate settings, collect and split images, caption them, save JSON."""
    validate_config()

    records = assign_split(collect_image_records(INPUT_IMAGE_DIR))
    print(f"[INFO] ์บก์ ๋ ๋์ ์ด๋ฏธ์ง ์: {len(records)}")

    model, processor, tokenizer, device = load_model()

    results = []
    batch_starts = range(0, len(records), BATCH_SIZE)
    for start in tqdm(batch_starts, desc="captioning"):
        results.extend(
            generate_captions_for_batch(
                batch_records=records[start:start + BATCH_SIZE],
                model=model,
                processor=processor,
                tokenizer=tokenizer,
                device=device,
            )
        )

    save_json(results, OUTPUT_JSON_PATH)


if __name__ == "__main__":
    main()