Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| import json | |
| import os | |
| import sys # ํ์ผ ๋๋ฝ ์ ์ข ๋ฃ๋ฅผ ์ํด ์ถ๊ฐ | |
| # ============================================================================== | |
| # 0. ImageNet ํด๋์ค ์ด๋ฆ ๋ก๋ | |
| # ============================================================================== | |
| CLASS_MAP_FILENAME = 'labels_map.txt' | |
| class_name_map = None # ์ ์ญ ๋ณ์๋ก ์ด๊ธฐํ | |
| try: | |
| if not os.path.exists(CLASS_MAP_FILENAME): | |
| print(f"[์ค๋ฅ] ํด๋์ค ์ด๋ฆ ํ์ผ('{CLASS_MAP_FILENAME}')์ ์ฐพ์ ์ ์์ต๋๋ค.") | |
| print("ํ์ผ์ ํ์ฌ ๋๋ ํ ๋ฆฌ์ ์ ์ฅํ๋์ง ํ์ธํด ์ฃผ์ธ์.") | |
| sys.exit(1) # ํ์ผ์ด ์์ผ๋ฉด ํ๋ก๊ทธ๋จ ์ข ๋ฃ | |
| # 1. ํ์ผ ๋ก๋ (JSON ํ์) | |
| with open(CLASS_MAP_FILENAME, 'r') as f: | |
| class_map_json = json.load(f) | |
| # 2. ์ ๊ณตํด์ฃผ์ ๋ก์ง ์ ์ฉ: ์ธ๋ฑ์ค 0๋ถํฐ 999๊น์ง ์ด๋ฆ๋ง ์ถ์ถํ์ฌ ๋ฆฌ์คํธ ์์ฑ | |
| # JSON ํ์ผ์ ํค๊ฐ ๋ฌธ์์ด์ด๋ฏ๋ก str(i)๋ก ์ ๊ทผํ๊ณ , ๊ฐ ๋ฆฌ์คํธ์ ๋ ๋ฒ์งธ ์์(์ด๋ฆ)๋ฅผ ๊ฐ์ ธ์ต๋๋ค. | |
| labels_list = [class_map_json[str(i)] for i in range(1000)] | |
| # 3. ์ธ๋ฑ์ค์ ์ด๋ฆ์ ๋งคํํ๋ ๋์ ๋๋ฆฌ๋ก ๋ณํ (๋์ค์ ํด๋์ค ID๋ก ์ด๋ฆ ์กฐํ ์ฉ์ด) | |
| class_name_map = {i: name for i, name in enumerate(labels_list)} | |
| print(f"ImageNet ํด๋์ค ์ด๋ฆ ({len(class_name_map)}๊ฐ) ๋ก๋ ์๋ฃ.") | |
| except Exception as e: | |
| print(f"[์ค๋ฅ] ํด๋์ค ํ์ผ ๋ก๋ ๋๋ ์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {e}") | |
| sys.exit(1) # ์ค๋ฅ ๋ฐ์ ์ ํ๋ก๊ทธ๋จ ์ข ๋ฃ | |
| # ============================================================================== | |
| # 1. ์ค์ ๋ฐ ๋ชจ๋ธ ๋ก๋ | |
| # ============================================================================== | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"์ฌ์ฉ ์ฅ์น: {device}") | |
| # ImageNet์ผ๋ก ์ฌ์ ํ๋ จ๋ EfficientNetB0 ๋ชจ๋ธ ๋ก๋ | |
| print("์ฌ์ ํ๋ จ๋ EfficientNetB0 ๋ชจ๋ธ ๋ก๋ ์ค...") | |
| model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1) | |
| model.eval() # ํ๊ฐ ๋ชจ๋ ์ค์ | |
| model = model.to(device) | |
| print("๋ชจ๋ธ ๋ก๋ ๋ฐ ํ๊ฐ ๋ชจ๋ ์ค์ ์๋ฃ.") | |
| # ============================================================================== | |
| # 2. ํ์ ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ ์ ์ | |
| # ============================================================================== | |
| preprocess = transforms.Compose([ | |
| transforms.Resize(256), | |
| transforms.CenterCrop(224), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
| ]) | |
| # ============================================================================== | |
| # 3. ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ฐ ์ถ๋ ฅ ํจ์ | |
| # ============================================================================== | |
| def classify_image(image_path_string): | |
| """ | |
| ์ฃผ์ด์ง ์ด๋ฏธ์ง ๊ฒฝ๋ก์ ํ์ผ์ EfficientNetB0 ๋ชจ๋ธ๋ก ๋ถ๋ฅํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํฉ๋๋ค. | |
| (ํด๋์ค ์ด๋ฆ ํฌํจ) | |
| """ | |
| try: | |
| # 1. ์ด๋ฏธ์ง ๋ก๋ ๋ฐ RGB ๋ณํ | |
| img = Image.open(image_path_string).convert('RGB') | |
| print(f"\n[INFO] ์ด๋ฏธ์ง ๋ก๋ ์ฑ๊ณต: {image_path_string}") | |
| input_tensor = preprocess(img) | |
| input_batch = input_tensor.unsqueeze(0).to(device) | |
| with torch.no_grad(): | |
| output = model(input_batch) | |
| probabilities = F.softmax(output[0], dim=0) | |
| top_prob, top_catid = torch.topk(probabilities, 5) | |
| print("\n--- ๋ถ๋ฅ ๊ฒฐ๊ณผ (Top-5) ---") | |
| for i in range(top_prob.size(0)): | |
| idx = top_catid[i].item() | |
| # ํด๋์ค ์ด๋ฆ ๋งคํ ์ ์ฉ: ๋ก๋๋ ๋์ ๋๋ฆฌ ์ฌ์ฉ | |
| class_name = class_name_map.get(idx, f"์ ์ ์๋ ํด๋์ค (ID: {idx})") | |
| print(f"์์ {i+1}:") | |
| print(f" - ํด๋์ค ์ด๋ฆ: **{class_name}**") | |
| print(f" - ํด๋์ค ์ธ๋ฑ์ค (ID): {idx}") | |
| print(f" - ํ๋ฅ : {top_prob[i].item():.4f}") | |
| except FileNotFoundError: | |
| print(f"\n[์ค๋ฅ] ์ด๋ฏธ์ง ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {image_path_string}") | |
| print("๊ฒฝ๋ก๋ฅผ ๋ค์ ํ์ธํด์ฃผ์ธ์.") | |
| except Exception as e: | |
| print(f"\n[์ค๋ฅ] ๋ถ๋ฅ ์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ต๋๋ค: {e}") | |
| # --- ์คํ --- | |
| # ๋ถ๋ฅํ ์ด๋ฏธ์ง ํ์ผ ๊ฒฝ๋ก๋ฅผ ๋ฌธ์์ด๋ก ์ง์ (์ฌ์ฉ์ ํ๊ฒฝ์ ๋ง๊ฒ ์์ ํ์!) | |
| CLASSIFY_TARGET_PATH = 'D:/pictures/muffin1.png' | |
| # ํจ์ ์คํ | |
| classify_image(CLASSIFY_TARGET_PATH) |