<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# Image captioning[[image-captioning]]

[[open-in-colab]]

Image captioning is the task of predicting a caption for a given image. A common real-world application is assisting
visually impaired people, for example by helping them navigate different situations. By describing images in text,
image captioning improves content accessibility.

This guide will show you how to:

* Fine-tune an image captioning model.
* Use the fine-tuned model for inference.

Before you begin, make sure you have all the necessary libraries installed (`jiwer` is required by the WER metric used later in this guide):

```bash
pip install transformers datasets evaluate -q
pip install jiwer -q
```
We encourage you to log in to your Hugging Face account so you can upload your model and share it with the community.
When prompted, enter your token to log in:

```python
from huggingface_hub import notebook_login

notebook_login()
```
## Load the Pokémon BLIP captions dataset[[load-the-pokmon-blip-captions-dataset]]

Use the 🤗 Datasets library to load a dataset that consists of {image-caption} pairs. To create your own image captioning dataset
in PyTorch, you can follow [this notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/GIT/Fine_tune_GIT_on_an_image_captioning_dataset.ipynb).

```python
from datasets import load_dataset

ds = load_dataset("lambdalabs/pokemon-blip-captions")
ds
```
```bash
DatasetDict({
    train: Dataset({
        features: ['image', 'text'],
        num_rows: 833
    })
})
```
The dataset has two features, `image` and `text`.

<Tip>

Many image captioning datasets contain multiple captions per image. In those cases, a common strategy is to randomly sample one caption from the available ones during training.

</Tip>
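The random-sampling strategy mentioned in the tip can be sketched in a few lines. This is a minimal illustration, not part of the guide's pipeline: the `records` structure and the helper names `sample_caption` and `make_training_pairs` are hypothetical, and the Pokémon dataset used here has only one caption per image, so it does not need this step.

```python
import random

# Hypothetical records: each image id maps to several reference captions.
records = [
    {"image_id": "img_0", "captions": ["a red bird", "a small red bird on a branch"]},
    {"image_id": "img_1", "captions": ["a blue turtle", "a cartoon turtle", "a turtle with a shell"]},
]


def sample_caption(record, rng):
    """Pick one caption uniformly at random from the record's caption list."""
    return rng.choice(record["captions"])


def make_training_pairs(records, rng):
    """Build (image_id, caption) pairs, drawing one caption per image."""
    return [(r["image_id"], sample_caption(r, rng)) for r in records]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
for epoch in range(2):
    # Each epoch draws captions again, so an image can be paired with
    # different captions over the course of training.
    print(epoch, make_training_pairs(records, rng))
```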
Split the dataset's train split into a train and test set with the [`~datasets.Dataset.train_test_split`] method:

```python
ds = ds["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]
test_ds = ds["test"]
```
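Conceptually, the split shuffles the row indices and holds out a fraction of them. The pure-Python sketch below illustrates that idea on index lists only. It is not how 🤗 Datasets implements `train_test_split`; the helper name `split_indices`, the seed, and the round-up rule for the test size are assumptions made for illustration. In the real call, passing a `seed` argument to `train_test_split` makes the split reproducible across runs.

```python
import math
import random


def split_indices(num_rows, test_size=0.1, seed=0):
    """Shuffle row indices and hold out a fraction of them for testing."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)
    # Assumption: the test set size is rounded up to a whole number of rows.
    n_test = math.ceil(num_rows * test_size)
    return indices[n_test:], indices[:n_test]


# 833 is the number of rows reported for the train split above.
train_idx, test_idx = split_indices(833, test_size=0.1, seed=0)
print(len(train_idx), len(test_idx))
```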
Let's visualize a couple of samples from the training set.

```python
from textwrap import wrap
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, captions):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        # One subplot per image, laid out in a single row
        ax = plt.subplot(1, len(images), i + 1)
        caption = captions[i]
        # Wrap long captions so neighboring titles do not overlap
        caption = "\n".join(wrap(caption, 12))
        plt.title(caption)
        plt.imshow(images[i])
        plt.axis("off")


sample_images_to_visualize = [np.array(train_ds[i]["image"]) for i in range(5)]
sample_captions = [train_ds[i]["text"] for i in range(5)]
plot_images(sample_images_to_visualize, sample_captions)
```
<div class="flex justify-center">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_training_images_image_cap.png" alt="Sample training images"/>
</div>

## Preprocess the dataset[[preprocess-the-dataset]]

Since the dataset has two modalities (image and text), the preprocessing pipeline handles both the images and the captions.

To do so, load the processor class associated with the model you are about to fine-tune.

```python
from transformers import AutoProcessor

checkpoint = "microsoft/git-base"
processor = AutoProcessor.from_pretrained(checkpoint)
```
The processor internally preprocesses the image (which includes resizing and pixel scaling) and tokenizes the caption.

```python
def transforms(example_batch):
    images = [x for x in example_batch["image"]]
    captions = [x for x in example_batch["text"]]
    inputs = processor(images=images, text=captions, padding="max_length")
    # For causal language modeling, the labels are the caption token ids themselves
    inputs.update({"labels": inputs["input_ids"]})
    return inputs


# The transform is applied on the fly, each time examples are accessed
train_ds.set_transform(transforms)
test_ds.set_transform(transforms)
```
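To give a rough picture of what "resizing and pixel scaling" mean, here is a small pure-Python sketch on a tiny grayscale image. It is not the processor's actual implementation: the helper names, the nearest-neighbor resize, and the `mean`/`std` values of 0.5 are assumptions chosen for illustration, and the real processor reads its target size and normalization statistics from the checkpoint's configuration.

```python
def resize_nearest(image, out_h, out_w):
    """Resize a 2D list of pixels with nearest-neighbor sampling."""
    in_h, in_w = len(image), len(image[0])
    return [
        [image[(y * in_h) // out_h][(x * in_w) // out_w] for x in range(out_w)]
        for y in range(out_h)
    ]


def rescale_and_normalize(image, mean=0.5, std=0.5):
    """Map 0..255 pixels to 0..1, then shift and scale by mean and std."""
    return [[(p / 255.0 - mean) / std for p in row] for row in image]


# Hypothetical 4x4 grayscale image with values in 0..255
image = [
    [0, 64, 128, 255],
    [255, 128, 64, 0],
    [0, 0, 255, 255],
    [255, 255, 0, 0],
]

small = resize_nearest(image, 2, 2)          # spatial resize to 2x2
pixel_values = rescale_and_normalize(small)  # values now lie in -1..1
print(small)
print(pixel_values)
```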
With the dataset ready, you can now set up the model for fine-tuning.

## Load a base model[[load-a-base-model]]

Load ["microsoft/git-base"](https://huggingface.co/microsoft/git-base) into an [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) object.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(checkpoint)
```
## Evaluate[[evaluate]]

Image captioning models are typically evaluated with the [Rouge Score](https://huggingface.co/spaces/evaluate-metric/rouge) or [Word Error Rate](https://huggingface.co/spaces/evaluate-metric/wer). For this guide, you will use the Word Error Rate (WER).

We use the 🤗 Evaluate library to do so. For potential limitations and other gotchas of the WER, refer to [this guide](https://huggingface.co/spaces/evaluate-metric/wer).

```python
from evaluate import load
import torch

wer = load("wer")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Take the highest-scoring token at each position
    predicted = logits.argmax(-1)
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}
```
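To make the metric concrete, the sketch below computes WER for a single reference/prediction pair as the word-level edit distance divided by the number of reference words. It is an illustration only, not the `wer` metric's implementation: the function name `word_error_rate` is hypothetical, and the 🤗 Evaluate metric aggregates errors over the whole batch rather than scoring one pair at a time.

```python
def word_error_rate(reference, prediction):
    """Word-level edit distance divided by the number of reference words."""
    ref, hyp = reference.split(), prediction.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the reference words seen so far
    # and the first j predicted words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = prev[j] + 1       # reference word missing from prediction
            insertion = curr[j - 1] + 1  # extra word in prediction
            curr.append(min(substitution, deletion, insertion))
        prev = curr
    return prev[-1] / len(ref)


score = word_error_rate(
    "a drawing of a pink and blue pokemon",
    "a drawing of a blue pokemon",
)
print(score)
```

Here the prediction drops two of the eight reference words, so lower scores mean closer captions and a score of 0 means an exact word-for-word match.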
## Train![[train!]]

Now, you are ready to start fine-tuning the model. You will use the 🤗 [`Trainer`] for this.

First, define the training arguments using [`TrainingArguments`].

```python
from transformers import TrainingArguments, Trainer

model_name = checkpoint.split("/")[1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=5e-5,
    num_train_epochs=50,
    fp16=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    save_total_limit=3,
    # Newer transformers releases rename this argument to `eval_strategy`
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=50,
    remove_unused_columns=False,
    push_to_hub=True,
    label_names=["labels"],
    load_best_model_at_end=True,
)
```
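A quick back-of-the-envelope calculation shows how these arguments interact. The sketch below assumes a single device and about 749 training examples (roughly 90% of the 833 rows); the helper name `estimate_schedule` is hypothetical, and the exact step count reported by `Trainer` can differ slightly depending on how the installed `transformers` version rounds the last partial accumulation window.

```python
import math


def estimate_schedule(num_train_examples, per_device_batch_size, grad_accum_steps,
                      num_epochs, eval_steps, num_devices=1):
    """Estimate optimizer steps and evaluation runs for a training setup."""
    # Examples consumed per optimizer update
    effective_batch_size = per_device_batch_size * grad_accum_steps * num_devices
    steps_per_epoch = math.ceil(num_train_examples / effective_batch_size)
    total_steps = steps_per_epoch * num_epochs
    return {
        "effective_batch_size": effective_batch_size,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "num_evaluations": total_steps // eval_steps,
    }


# Values taken from the TrainingArguments above; 749 examples is an assumption.
schedule = estimate_schedule(
    num_train_examples=749,
    per_device_batch_size=32,
    grad_accum_steps=2,
    num_epochs=50,
    eval_steps=50,
)
print(schedule)
```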
Pass the training arguments, along with the datasets and the model, to 🤗 Trainer.

```python
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)
```
To start training, simply call [`~Trainer.train`] on the [`Trainer`] object.

```python
trainer.train()
```
You should see the training loss drop smoothly as training progresses.

Once training is completed, share your model to the Hub with the [`~Trainer.push_to_hub`] method so everyone can use your model:

```python
trainer.push_to_hub()
```
## Inference[[inference]]

Take a sample image to test the model. The example below downloads one from a URL; an image taken from `test_ds` can be used the same way.

```python
from PIL import Image
import requests

url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)
image
```
<div class="flex justify-center">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/test_image_image_cap.png" alt="Test image"/>
</div>

Prepare the image for the model.

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
```
Call [`generate`] and decode the predictions.

```python
generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)
```
```bash
a drawing of a pink and blue pokemon
```

Looks like the fine-tuned model generated a pretty good caption!
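For intuition about what happens during generation, the sketch below shows greedy decoding: the caption grows one token at a time, always taking the highest-scoring next token, until an end-of-sequence token appears or the length limit is reached. This is a toy illustration under stated assumptions, not the library's `generate` implementation: the `TOY_SCORES` table stands in for the model, the token names are made up, and the decoding strategy actually used by `generate` depends on the model's generation configuration.

```python
# Hypothetical next-token scores keyed by the most recent token.
TOY_SCORES = {
    "<bos>": {"a": 0.9, "the": 0.1},
    "a": {"drawing": 0.7, "pokemon": 0.3},
    "drawing": {"of": 0.8, "<eos>": 0.2},
    "of": {"pokemon": 0.6, "a": 0.4},
    "pokemon": {"<eos>": 0.9, "of": 0.1},
}


def greedy_decode(scores, max_length=50, bos="<bos>", eos="<eos>"):
    """Append the highest-scoring next token until eos or max_length."""
    tokens = [bos]
    # max_length counts every token in the sequence, including the start token
    while len(tokens) < max_length:
        candidates = scores[tokens[-1]]
        next_token = max(candidates, key=candidates.get)
        if next_token == eos:
            break
        tokens.append(next_token)
    return tokens[1:]  # drop the start token


print(" ".join(greedy_decode(TOY_SCORES)))
```

Lowering `max_length` in this sketch truncates the caption early, which mirrors the role of the `max_length=50` argument passed to `generate` above.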