Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import random | |
| from pathlib import Path | |
| import torch | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from dotenv import load_dotenv | |
| from transformers import AutoProcessor, Florence2ForConditionalGeneration | |
# =========================================================
# 1. Settings
# =========================================================

# --- Paths -------------------------------------------------
# Directory scanned for images:
#   caption every class      -> "data/raw"
#   caption a single class   -> "data/raw/apple"
INPUT_IMAGE_DIR = "data/raw"
# Base directory that "image" values are made relative to,
# producing keys such as "pizza/hf_pizza_001.jpg".
DATA_RAW_ROOT = "data/raw"
# Where the resulting JSON file is written.
OUTPUT_JSON_PATH = "data/annotations/captions_flo_all.json"

# --- Model -------------------------------------------------
# With transformers 5.7.0 the florence-community checkpoints are recommended.
#   base-ft:  lightweight, fine-tuned on downstream tasks
#   large-ft: heavier, but may give better quality
MODEL_ID = "florence-community/Florence-2-base-ft"
# MODEL_ID = "florence-community/Florence-2-large-ft"
# Name of the environment variable (may be provided via a .env file) that
# holds the Hugging Face token. Public models may work without it, but
# setting it is the more reliable choice.
HF_TOKEN_ENV_NAME = "HF_TOKEN"

# --- Dataset split ----------------------------------------
# Relative weights for train / val / test (default 7 : 1.5 : 1.5).
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.7, 0.15, 0.15
# Seed that makes the split reproducible.
RANDOM_SEED = 42

# --- Captioning -------------------------------------------
# Three captions per image, one for each Florence-2 caption task.
CAPTION_TASKS = ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
# Generation options.
# NOTE(review): 64 new tokens may cut off longer outputs such as
# <MORE_DETAILED_CAPTION>; consider raising it if captions end mid-sentence.
NUM_BEAMS, MAX_NEW_TOKENS = 3, 64

# --- Run control ------------------------------------------
# Save intermediate results every this many newly captioned images.
SAVE_EVERY = 220
# Skip images that already have an entry in the output JSON.
SKIP_ALREADY_DONE = True
# Accepted image file extensions (matched against the lower-cased suffix).
IMAGE_EXTENSIONS = [
    ".jpg",
    ".jpeg",
    ".png",
    ".webp",
    ".bmp",
]
| # ========================================================= | |
# 2. Collect the image list
| # ========================================================= | |
def get_image_list(*, input_image_dir=None, data_raw_root=None, image_extensions=None):
    """Collect every image file below the input directory.

    Args:
        input_image_dir: Directory scanned recursively. Defaults to
            INPUT_IMAGE_DIR.
        data_raw_root: Root that "image" keys are made relative to; the first
            path component below it is used as the class name. Defaults to
            DATA_RAW_ROOT.
        image_extensions: Accepted file suffixes (with the leading dot),
            compared case-insensitively. Defaults to IMAGE_EXTENSIONS.

    Returns:
        A list of dicts in sorted path order, each with keys:
            "path":  Path of the file,
            "image": POSIX path relative to the raw root,
                     e.g. "pizza/hf_pizza_001.jpg",
            "class": first component of that relative path, e.g. "pizza".
        Files lying directly in the raw root (no class folder) are skipped
        with a printed warning.

    Raises:
        FileNotFoundError: If the input directory does not exist.
        ValueError: If the input directory is not inside the raw root.
    """
    input_dir = Path(INPUT_IMAGE_DIR if input_image_dir is None else input_image_dir).resolve()
    raw_root = Path(DATA_RAW_ROOT if data_raw_root is None else data_raw_root).resolve()
    extensions = {
        ext.lower()
        for ext in (IMAGE_EXTENSIONS if image_extensions is None else image_extensions)
    }

    if not input_dir.exists():
        raise FileNotFoundError(f"์ ๋ ฅ ๊ฒฝ๋ก๊ฐ ์์ต๋๋ค: {input_dir}")

    # Check containment once, with a clear message, instead of failing inside
    # the loop on the first file.
    try:
        input_dir.relative_to(raw_root)
    except ValueError:
        raise ValueError(
            f"Input directory {input_dir} is not inside the raw data root {raw_root}"
        ) from None

    image_list = []
    for image_path in sorted(input_dir.rglob("*")):
        # rglob also yields directories; a folder named e.g. "x.jpg" is not an image.
        if not image_path.is_file() or image_path.suffix.lower() not in extensions:
            continue

        # input_dir is already resolved, so every rglob result is lexically under
        # raw_root. Resolving each file again would follow file symlinks and could
        # point outside raw_root, making relative_to() raise.
        #   /workspace/data/raw/pizza/hf_pizza_001.jpg -> pizza/hf_pizza_001.jpg
        relative_path = image_path.relative_to(raw_root)
        if len(relative_path.parts) < 2:
            # A file directly in raw_root has no class folder; its file name
            # would otherwise be used as the class.
            print(f"Skipping image without a class folder: {image_path}")
            continue

        image_list.append({
            "path": image_path,
            "image": relative_path.as_posix(),
            # pizza/hf_pizza_001.jpg -> pizza
            "class": relative_path.parts[0],
        })
    return image_list
| # ========================================================= | |
# 3. Split into train / val / test
| # ========================================================= | |
def add_split(image_list, *, seed=None, ratios=None):
    """Assign a stratified train/val/test split to every image.

    Images are grouped by their "class" value; each group is shuffled and cut
    independently, so every class is split in roughly the configured
    proportions. Counts use round(), and whatever is left after train and val
    goes to test.

    Args:
        image_list: Iterable of dicts that each have at least a "class" key.
        seed: Shuffle seed. Defaults to RANDOM_SEED.
        ratios: (train, val, test) weights; they need not sum to 1.
            Defaults to (TRAIN_RATIO, VAL_RATIO, TEST_RATIO).

    Returns:
        A new list of dicts, each a shallow copy of an input item plus a
        "split" key ("train", "val" or "test"). Items are grouped by class in
        first-seen order and shuffled within each class. The input dicts are
        not modified.

    Raises:
        ValueError: If the ratios do not sum to a positive number.
    """
    if seed is None:
        seed = RANDOM_SEED
    if ratios is None:
        ratios = (TRAIN_RATIO, VAL_RATIO, TEST_RATIO)
    train_ratio, val_ratio, test_ratio = ratios
    total_ratio = train_ratio + val_ratio + test_ratio
    if total_ratio <= 0:
        raise ValueError(f"split ratios must sum to a positive number, got {ratios}")

    # A private generator keeps the split reproducible without reseeding the
    # global `random` module. Random(seed).shuffle() produces the same order
    # as random.seed(seed) + random.shuffle(), so existing splits are unchanged.
    rng = random.Random(seed)

    # Group images by class, preserving first-seen class order.
    class_map = {}
    for item in image_list:
        class_map.setdefault(item["class"], []).append(item)

    result = []
    for items in class_map.values():
        shuffled = list(items)
        rng.shuffle(shuffled)
        total_count = len(shuffled)
        train_count = round(total_count * train_ratio / total_ratio)
        val_count = round(total_count * val_ratio / total_ratio)
        for index, item in enumerate(shuffled):
            if index < train_count:
                split = "train"
            elif index < train_count + val_count:
                split = "val"
            else:
                split = "test"
            result.append({**item, "split": split})
    return result
| # ========================================================= | |
# 4. Prepare the Florence-2 model
| # ========================================================= | |
def load_model():
    """Load the Florence-2 processor and model onto the best available device.

    The Hugging Face token is read from the environment variable named by
    HF_TOKEN_ENV_NAME after loading a .env file. On CUDA, bfloat16 is used
    when the GPU supports it and float16 otherwise; on CPU, float32 is used.

    Returns:
        Tuple (model, processor, device, torch_dtype). The model has been
        moved to `device` and switched to eval mode.
    """
    load_dotenv()
    token = os.getenv(HF_TOKEN_ENV_NAME)

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    if not use_cuda:
        torch_dtype = torch.float32
    elif torch.cuda.is_bf16_supported():
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float16

    for label, value in (("device", device), ("dtype", torch_dtype), ("model", MODEL_ID)):
        print(f"{label}: {value}")

    processor = AutoProcessor.from_pretrained(MODEL_ID, token=token)
    model = Florence2ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        dtype=torch_dtype,
        token=token,
    )
    model = model.to(device)
    model.eval()
    return model, processor, device, torch_dtype
| # ========================================================= | |
# 5. Caption a single image
| # ========================================================= | |
def make_caption(image, task, model, processor, device, torch_dtype):
    """Run one Florence-2 caption task on a single image.

    Args:
        image: PIL image to caption; its `.size` is passed to post-processing.
        task: Florence-2 task prompt such as "<CAPTION>". It is also the key
            used to read the caption from the post-processed result.
        model, processor, device, torch_dtype: As returned by load_model().

    Returns:
        The caption with surrounding whitespace stripped, or "" if the parsed
        result has no entry for `task`. A non-string result is passed through
        str() first.
    """
    model_inputs = processor(text=task, images=image, return_tensors="pt").to(device, torch_dtype)

    # Deterministic beam search; no gradients are needed for inference.
    with torch.no_grad():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=NUM_BEAMS,
            do_sample=False,
        )

    # Special tokens are kept in the decoded text handed to post_process_generation.
    raw_text = processor.batch_decode(output_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(raw_text, task=task, image_size=image.size)

    caption = parsed.get(task, "")
    return (caption if isinstance(caption, str) else str(caption)).strip()
def make_three_captions(image_path, model, processor, device, torch_dtype):
    """Caption one image file once for every task in CAPTION_TASKS.

    Args:
        image_path: Path of the image file.
        model, processor, device, torch_dtype: As returned by load_model().

    Returns:
        List of captions, one per entry of CAPTION_TASKS, in the same order.
    """
    # Image.open() opens the file lazily and keeps the handle; the `with`
    # block closes it once the RGB copy has been made, even if decoding fails.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    return [
        make_caption(
            image=image,
            task=task,
            model=model,
            processor=processor,
            device=device,
            torch_dtype=torch_dtype,
        )
        for task in CAPTION_TASKS
    ]
| # ========================================================= | |
# 6. Read / save the existing JSON
| # ========================================================= | |
def load_existing_result(output_json_path=None):
    """Load previously saved caption entries, keyed by their "image" value.

    Args:
        output_json_path: JSON file to read. Defaults to OUTPUT_JSON_PATH.

    Returns:
        Dict mapping each entry's "image" value to the entry dict, or an empty
        dict if the file does not exist yet.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON. This is
            deliberately not swallowed: returning {} would let the next save
            overwrite the unreadable file and discard all earlier captions.
    """
    output_path = Path(OUTPUT_JSON_PATH if output_json_path is None else output_json_path)
    if not output_path.exists():
        return {}

    with output_path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    return {item["image"]: item for item in data}
def save_result(result_map, output_json_path=None):
    """Write all caption entries to the output JSON file, sorted by "image".

    The data is written to a sibling ".tmp" file first and then moved over the
    target. An interruption mid-write therefore leaves the previous file
    intact, rather than a truncated JSON that cannot be loaded to resume.

    Args:
        result_map: Dict whose values are entry dicts containing an "image" key.
        output_json_path: Destination file. Defaults to OUTPUT_JSON_PATH.
    """
    output_path = Path(OUTPUT_JSON_PATH if output_json_path is None else output_json_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    result_list = sorted(result_map.values(), key=lambda entry: entry["image"])

    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(result_list, f, ensure_ascii=False, indent=4)
        # Path.replace overwrites an existing destination; on POSIX it is an
        # atomic rename when both paths are on the same filesystem.
        tmp_path.replace(output_path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise
| # ========================================================= | |
# 7. Run
| # ========================================================= | |
def main():
    """Caption every image under INPUT_IMAGE_DIR and save the results as JSON.

    Steps: list the images, assign stratified splits, load existing results,
    load the model, then caption each image that is not already present
    (when SKIP_ALREADY_DONE is True). Results are saved every SAVE_EVERY newly
    captioned images, and once more when the loop ends. That final save also
    runs if the loop is interrupted (e.g. Ctrl+C) or an unexpected error
    escapes it.
    """
    print("์ด๋ฏธ์ง ๋ชฉ๋ก์ ์ฝ๋ ์ค์ ๋๋ค.")
    image_list = add_split(get_image_list())
    print(f"์ด ์ด๋ฏธ์ง ์: {len(image_list)}")

    result_map = load_existing_result()
    model, processor, device, torch_dtype = load_model()

    new_count = 0
    skip_count = 0
    fail_count = 0

    try:
        for item in tqdm(image_list):
            image_key = item["image"]
            if SKIP_ALREADY_DONE and image_key in result_map:
                skip_count += 1
                continue

            # Only captioning is guarded: a failure here is a per-image failure.
            # Save errors are not counted as image failures.
            try:
                captions = make_three_captions(
                    image_path=item["path"],
                    model=model,
                    processor=processor,
                    device=device,
                    torch_dtype=torch_dtype,
                )
            except Exception as e:
                # Best effort: report the failure and move on to the next image.
                # tqdm.write keeps the messages from breaking the progress bar.
                fail_count += 1
                tqdm.write(f"\n์คํจํ ์ด๋ฏธ์ง: {item['path']}")
                tqdm.write(f"์๋ฌ ๋ด์ฉ: {e}")
                continue

            result_map[image_key] = {
                "image": item["image"],
                "class": item["class"],
                "captions": captions,
                "split": item["split"],
            }
            new_count += 1
            # Periodic checkpoint: bounds the loss if the process is killed
            # outright (where the `finally` below cannot run).
            if new_count % SAVE_EVERY == 0:
                save_result(result_map)
    finally:
        # Runs on normal completion and on KeyboardInterrupt / unexpected errors,
        # so work done since the last checkpoint is not lost.
        save_result(result_map)

    print("\n์บก์ ๋ ์๋ฃ")
    print(f"์๋ก ์ฒ๋ฆฌํ ์ด๋ฏธ์ง ์: {new_count}")
    print(f"๊ฑด๋๋ด ์ด๋ฏธ์ง ์: {skip_count}")
    print(f"์คํจํ ์ด๋ฏธ์ง ์: {fail_count}")
    print(f"์ ์ฅ ์์น: {OUTPUT_JSON_PATH}")


if __name__ == "__main__":
    main()