Spaces:
Build error
Build error
| import math | |
| import torch | |
| from timm import create_model | |
| from torchvision import transforms | |
| from PIL import Image | |
| # ๊ธฐ๋ณธ ํฌ๋กญ ๋น์จ | |
| DEFAULT_CROP_PCT = 0.875 | |
| # EfficientNet-B0 ๋ชจ๋ธ ์ค์ ๋ฐ ๋ก๋ | |
| weights_path = "./weights/resnest101e.in1k_weight_Pear_classification.pt" # ๋ก์ปฌ ๊ฐ์ค์น ํ์ผ ๊ฒฝ๋ก | |
| model_name = "resnest101e" | |
| model = create_model(model_name, pretrained=False,num_classes=9) # ์ฌ์ ํ์ต ๋ก๋ ์๋ต | |
| #model.classifier = torch.nn.Linear(model.classifier.in_features, 2) # ์ด์ง ๋ถ๋ฅ๋ก ์์ | |
| model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # ๋ก์ปฌ ๊ฐ์ค์น ๋ก๋ | |
| model.eval() # ํ๊ฐ ๋ชจ๋ ์ค์ | |
| # ํด๋์ค ์ด๋ฆ ์๋ ์ง์ | |
| # class_labels = ["Abormal Pear", "Normal Pear"] # ๋ฐ์ดํฐ์ ์ ๋ง๊ฒ ์์ | |
| class_labels = ["์ ์", "ํ์ฑ๋ณ","๊ณผํผ์ผ๋ฃฉ","๋ณต์ญ์ ์๋๋ฐฉ","๋ณต์ญ์ ์ผ์๋๋ฐฉ","๋ฐฐ ๊น์ง๋ฒ๋ ","์๋ง์ด ๋๋ฐฉ๋ฅ", "๊ธฐํ","๊ณผํผํ๋ณ"] | |
| # ์ ์ฒ๋ฆฌ ํจ์ | |
| def transforms_imagenet_eval( | |
| img_path: str, | |
| img_size: int = 224, | |
| crop_pct: float = DEFAULT_CROP_PCT, | |
| mean: tuple = (0.485, 0.456, 0.406), | |
| std: tuple = (0.229, 0.224, 0.225), | |
| normalize: bool = True, | |
| ): | |
| """ | |
| ImageNet ์คํ์ผ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ ํจ์. | |
| Args: | |
| img_path (str): ์ด๋ฏธ์ง ๊ฒฝ๋ก. | |
| img_size (int): ํฌ๋กญ ํฌ๊ธฐ. | |
| crop_pct (float): ํฌ๋กญ ๋น์จ. | |
| mean (tuple): ์ ๊ทํ ํ๊ท . | |
| std (tuple): ์ ๊ทํ ํ์คํธ์ฐจ. | |
| normalize (bool): ์ ๊ทํ ์ฌ๋ถ. | |
| Returns: | |
| torch.Tensor: ์ ์ฒ๋ฆฌ๋ ์ด๋ฏธ์ง ํ ์. | |
| """ | |
| img = Image.open(img_path).convert("RGB") # ์ด๋ฏธ์ง ๋ก๋ ๋ฐ RGB ๋ณํ | |
| scale_size = math.floor(img_size / crop_pct) # ๋ฆฌ์ฌ์ด์ฆ ํฌ๊ธฐ ๊ณ์ฐ | |
| # Transform ์ค์ | |
| tfl = [ | |
| transforms.Resize((scale_size, scale_size), interpolation=transforms.InterpolationMode.BILINEAR), | |
| transforms.CenterCrop(img_size), | |
| transforms.ToTensor(), | |
| ] | |
| if normalize: | |
| tfl += [transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))] | |
| transform = transforms.Compose(tfl) | |
| return transform(img) | |
| # ์ถ๋ก ํจ์ | |
| def predict(image_path: str): | |
| """ | |
| ์ฃผ์ด์ง ์ด๋ฏธ์ง ํ์ผ ๊ฒฝ๋ก๋ฅผ ๋ฐ์ ๋ชจ๋ธ ์ถ๋ก ์ ์ํ. | |
| Args: | |
| image_path (str): ์ ๋ ฅ ์ด๋ฏธ์ง ๊ฒฝ๋ก. | |
| Returns: | |
| str: ๋ชจ๋ธ ์์ธก ๊ฒฐ๊ณผ. | |
| """ | |
| # ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ | |
| input_tensor = transforms_imagenet_eval( | |
| img_path=image_path, | |
| img_size=224, | |
| normalize=True | |
| ).unsqueeze(0) # ๋ฐฐ์น ์ฐจ์ ์ถ๊ฐ | |
| # ๋ชจ๋ธ ์ถ๋ก | |
| with torch.no_grad(): | |
| prediction = model(input_tensor) | |
| probs = torch.nn.functional.softmax(prediction[0], dim=-1) | |
| confidences = {class_labels[i]: float(probs[i]) for i in range(9)} | |
| # ์์ธก ๊ฒฐ๊ณผ ๋ฐํ | |
| return confidences |