Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from transformers import ViTForImageClassification, ViTModel, ViTImageProcessor | |
| from PIL import Image | |
| from sklearn.preprocessing import LabelEncoder | |
| import pandas as pd | |
| from PIL import Image, ExifTags | |
def greet(name):
    """Return a Chinese greeting ("你好,<name>!!") addressed to *name*."""
    return "你好,{}!!".format(name)
| # 圖像預處理 | |
| from PIL import Image, ExifTags | |
def preprocess_image(image):
    """Convert an input image into a pixel tensor for the ViT classifier.

    Accepts a NumPy array, a PIL image, or anything ``PIL.Image.open``
    accepts (a path or file object). The EXIF Orientation tag (e.g. as
    written by iPhone cameras) is applied so the pixels are upright before
    preprocessing.

    Args:
        image: ``np.ndarray``, ``PIL.Image.Image``, or a path/file object.

    Returns:
        torch.Tensor: float32 tensor built from the first entry of the
        module-level ``feature_extractor``'s ``pixel_values`` output, i.e.
        the single image without a batch dimension.
    """
    # Local import keeps this fix self-contained; PIL is already a dependency.
    from PIL import ImageOps

    # Normalize the input to a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(np.uint8(image))
    elif not isinstance(image, Image.Image):
        image = Image.open(image)

    # Apply the EXIF Orientation tag. exif_transpose handles all eight
    # orientations (including the mirrored ones) and images without EXIF.
    # The private Image._getexif() previously used exists only for some
    # formats.
    try:
        image = ImageOps.exif_transpose(image)
    except Exception:
        # Best effort: malformed EXIF must not block classification.
        # Unlike a bare `except:`, this does not swallow KeyboardInterrupt.
        pass

    # The processor expects 3-channel RGB input.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = feature_extractor(images=[image])
    return torch.tensor(inputs['pixel_values'][0], dtype=torch.float32)
def predict(image_tensor, top_k=5):
    """Classify one preprocessed image and tabulate the top-k predictions.

    Args:
        image_tensor: pixel tensor for a single image without a batch
            dimension (as returned by ``preprocess_image``); a batch
            dimension of 1 is added here.
        top_k: number of rows to return. Coerced to ``int`` because UI
            sliders may deliver floats, and clamped to the number of classes
            so it can never index past the model's outputs.

    Returns:
        pandas.DataFrame with columns 排名 (1-based rank), 藥丸名稱 (class
        name decoded by the module-level ``encoder``) and 機率 (softmax
        probability rounded to 4 decimals), ordered by descending
        probability.
    """
    model.eval()
    with torch.no_grad():
        logits = model(pixel_values=image_tensor.unsqueeze(0)).logits

    # The 機率 ("probability") column must hold probabilities, not raw
    # logits, so normalize with softmax over the class axis.
    probs = torch.softmax(logits, dim=-1)[0]

    num_classes = probs.shape[-1]
    k = max(0, min(int(top_k), num_classes))
    top_probs, top_indices = torch.topk(probs, k)

    # One vectorized decode instead of one inverse_transform call per row.
    class_names = encoder.inverse_transform(top_indices.cpu().numpy())
    rows = [
        [rank, name, round(float(p), 4)]
        for rank, (name, p) in enumerate(
            zip(class_names, top_probs.cpu().tolist()), start=1
        )
    ]
    return pd.DataFrame(rows, columns=["排名", "藥丸名稱", "機率"])
def _message_frame(message):
    """Wrap a user-facing message in a one-cell DataFrame for gr.Dataframe."""
    return pd.DataFrame({"訊息": [message]})


def classify_pill(file, top_k: int = 5):
    """Gradio entry point: return an image preview and a top-k result table.

    Args:
        file: the uploaded image (an ``np.ndarray`` when wired to
            ``gr.Image(type="numpy")``, or anything ``preprocess_image``
            accepts), or ``None`` when nothing was uploaded.
        top_k: number of predictions to list; coerced to ``int`` because a
            UI slider may deliver a float.

    Returns:
        ``(preview, table)``. On success, ``preview`` is a PIL image (arrays
        are converted; other inputs are passed through) and ``table`` is the
        DataFrame from ``predict``. On missing input or any failure,
        ``preview`` is ``None`` and ``table`` is a one-cell DataFrame holding
        the user-facing message. A bare ``str`` cannot be used there:
        ``gr.Dataframe`` treats a str value as a CSV file path, so the
        message would never be displayed.
    """
    if file is None:
        return None, _message_frame("⚠️ 請上傳一張藥丸圖片!")
    try:
        image_tensor = preprocess_image(file)
        df = predict(image_tensor, int(top_k))
        # Echo the uploaded image back as a preview.
        if isinstance(file, np.ndarray):
            img_display = Image.fromarray(np.uint8(file))
        else:
            img_display = file
        return img_display, df
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing.
        return None, _message_frame(f"❌ 預測失敗,錯誤訊息:{e}")
# Load the LabelEncoder class list (model output index -> pill name).
encoder = LabelEncoder()
# NOTE(security): allow_pickle=True unpickles encoder.npy; only ever load a
# trusted file bundled with the app.
encoder.classes_ = np.load('encoder.npy', allow_pickle=True)

# Image preprocessing: resize to 224 and normalize with mean/std 0.5.
# NOTE(review): do_rescale=False means PIL pixels stay in 0..255 before
# normalization (roughly -1..509 instead of -1..1). Confirm this matches how
# the model was trained.
feature_extractor = ViTImageProcessor(
    image_size=224,
    do_resize=True,
    do_normalize=True,
    do_rescale=False,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)

# Load backbone AND classification head from the checkpoint. The previous
# approach, ViTForImageClassification(config) with only `.vit` replaced by a
# pretrained backbone, left `classifier` randomly initialized on every start.
# If the checkpoint has no head, or one of a different size,
# ignore_mismatched_sizes=True makes transformers initialize it (with a
# logged warning) instead of failing — check the startup logs.
model = ViTForImageClassification.from_pretrained(
    'pillIdentifierAI/pillIdentifier',
    num_labels=len(encoder.classes_),
    ignore_mismatched_sizes=True,
)
# Keep the previous module-level names available.
config = model.config
pretrained_model = model.vit
model.eval()
def _build_interface():
    """Assemble the pill-classifier Gradio Interface.

    Inputs are an image upload and a 1-10 top-k slider. Outputs are an image
    preview and the prediction table, both produced by ``classify_pill``.
    """
    upload_components = [
        gr.Image(type="numpy", label="📸 上傳藥丸圖片"),
        gr.Slider(1, 10, value=5, step=1, label="🔢 顯示前幾個預測結果"),
    ]
    result_components = [
        gr.Image(label="🔍 上傳圖片預覽"),
        gr.Dataframe(label="📝 預測結果(中文表格)", headers=["排名", "藥丸名稱", "機率"]),
    ]
    return gr.Interface(
        fn=classify_pill,
        inputs=upload_components,
        outputs=result_components,
        title="藥丸辨識器 💊",
        description="上傳藥丸圖片,我們會顯示圖片預覽並以表格形式列出前幾個可能的藥名與機率。",
    )


# Build the UI and start serving; share=True asks Gradio for a public share link.
iface = _build_interface()
iface.launch(share=True)