# Hugging Face Space page residue (status at capture time: Sleeping)
import functools
import io
import logging

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
# Checkpoint location handed to `from_pretrained` for both the model and its
# image processor (a Hugging Face Hub repo id or a local directory).
target_folder = 'JungminChung/India_ResNet'
| def load_model_and_preprocessor(target_folder): | |
| model = ResNetForImageClassification.from_pretrained(target_folder) | |
| image_processor = AutoImageProcessor.from_pretrained(target_folder) | |
| return model, image_processor | |
def fetch_image(url, timeout=10):
    """Download the image at *url* and return it as a fully loaded PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait on the server (connect and read) before
            giving up; pass ``None`` to wait indefinitely.

    Returns:
        A ``PIL.Image.Image`` whose pixel data has already been decoded.

    Raises:
        requests.RequestException: on network errors, timeouts, or a
            4xx/5xx HTTP status.
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    # Browser-like User-Agent; presumably because some image hosts reject
    # the default python-requests agent.
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36'
    }
    with requests.get(url, headers=headers, timeout=timeout) as response:
        # Fail on error pages instead of handing HTML to PIL.
        response.raise_for_status()
        # `.content` is already gzip/deflate-decoded (unlike `.raw`) and fully
        # buffered, so the connection can be released when the block exits.
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data fails here.
        image.load()
    return image
def infer_image(image, model, image_processor, k):
    """Classify *image* and format its top-*k* predictions as text.

    Args:
        image: PIL image (anything with ``.convert("RGB")``); it is
            converted to RGB before preprocessing.
        model: Classification model whose output exposes ``.logits`` and
            whose ``config.id2label`` maps class indices to label names.
        image_processor: Callable turning an image into model inputs
            (called with ``return_tensors="pt"``).
        k: Number of predictions to list. The Gradio slider may deliver
            it as a float, so it is truncated to ``int`` and clamped to
            ``[0, number_of_classes]`` (``torch.topk`` rejects anything else).

    Returns:
        One line per prediction for the first batch item, ordered by
        decreasing probability, formatted as
        ``"<rank>. <label padded to 15> (<percent> %) \\n"``; an empty
        string when ``k <= 0``.
    """
    inputs = image_processor(images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)

    num_classes = probs.shape[-1]
    k = max(0, min(int(k), num_classes))
    topk_probs, topk_indices = torch.topk(probs, k=k)

    id2label = model.config.id2label
    lines = (
        f"{rank}. {id2label[index.item()]:<15} ({p.item() * 100:.2f} %) \n"
        for rank, (p, index) in enumerate(
            zip(topk_probs[0], topk_indices[0]), start=1
        )
    )
    return "".join(lines)
| def infer(url, k, target_folder=target_folder): | |
| try : | |
| image = fetch_image(url) | |
| model, image_processor = load_model_and_preprocessor(target_folder) | |
| res = infer_image(image, model, image_processor, k) | |
| except : | |
| image = Image.new('RGB', (224, 224)) | |
| res = "์ด๋ฏธ์ง๋ฅผ ๋ถ๋ฌ์ค๋๋ฐ ๋ฌธ์ ๊ฐ ์๋๋ด์. ๋ค๋ฅธ ์ด๋ฏธ์ง url๋ก ๋ค์ ์๋ํด์ฃผ์ธ์." | |
| return image, res | |
# Web UI: an image URL and a top-k slider in; the fetched image and the
# ranked predictions out.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(
            value="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRpE-UHBp8ZufNUd3BKw8gtIxSe3IUwspOfqw&s",
            label="Image URL",
        ),
        # Label: "How many of the top predictions should be shown?"
        # minimum=1: a value of 0 would always produce an empty result.
        gr.Slider(minimum=1, maximum=20, step=1, value=3, label="상위 몇개까지 보여줄까요?"),
    ],
    outputs=[
        gr.Image(type="pil", label="입력 이미지"),  # "Input image"
        gr.Textbox(label="종류 (확률)"),  # "Class (probability)"
    ],
)
demo.launch()
# For a temporary public link when running locally, use demo.launch(share=True).