| | --- |
| | license: apache-2.0 |
| | --- |
| | |
| | ๋ก๋งจ์ค ์ค์บ ์ฌ์ง๊ณผ, ๊ทธ๋ฅ ์ฌ์ง์ ๊ตฌ๋ณํ ์ ์๋ ViT ๋ชจ๋ธ ์
๋๋ค. |
| | ๊ธฐ์กด์ CNN ๋ชจ๋ธ์ ๋นํด ํจ์ ์ฑ๋ฅ์ด ์ข์ต๋๋ค. |
| | ์ถํ ๋ฐ์ดํฐ๋ฅผ ์ถ๊ฐํด ์ฑ๋ฅ์ ๋์ฑ ๋๋ฆด๊ฒ ์
๋๋ค. |
| | ์ฌ์ฉ ์ฝ๋๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค. |
| |
|
| | ```python |
| | import torch |
| | from transformers import ViTForImageClassification, ViTFeatureExtractor |
| | from PIL import Image |
| | |
| | # Hugging Face์์ ๋ชจ๋ธ ๋ฐ ํน์ง ์ถ์ถ๊ธฐ ๋ถ๋ฌ์ค๊ธฐ |
| | model = ViTForImageClassification.from_pretrained("gihakkk/vit_modle") |
| | feature_extractor = ViTFeatureExtractor.from_pretrained("gihakkk/vit_modle") |
| | |
| | # ์๋ก์ด ์ด๋ฏธ์ง ์์ธก ํจ์ ์ ์ |
| | def predict_image(image_path): |
| | # ์ด๋ฏธ์ง๋ฅผ ๋ก๋ํ๊ณ RGB๋ก ๋ณํ |
| | image = Image.open(image_path).convert("RGB") |
| | |
| | # ์ด๋ฏธ์ง๋ฅผ ํน์ง ์ถ์ถ๊ธฐ๋ก ์ ์ฒ๋ฆฌํ์ฌ ๋ชจ๋ธ ์
๋ ฅ ํ์์ผ๋ก ๋ณํ |
| | inputs = feature_extractor(images=image, return_tensors="pt") |
| | |
| | # ์์ธก ์ํ |
| | with torch.no_grad(): |
| | outputs = model(**inputs).logits |
| | predicted_class = torch.argmax(outputs, dim=-1).item() |
| | |
| | return "๊ทธ๋ฅ ์ฌ์ง" if predicted_class == 1 else "๋ก๋งจ์ค ์ค์บ ์ฌ์ง" |
| | |
| | # ์์ธก ์์ |
| | image_path = r'path\to\your\img.jpg' |
| | result = predict_image(image_path) |
| | print(result) |
| | |
| | ``` |