# NOTE: the original paste began with a Hugging Face Spaces status banner
# ("Spaces: Runtime error") that is not Python code; preserved here as a comment.
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| # In[ ]: | |
| # This Python 3 environment comes with many helpful analytics libraries installed | |
| # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python | |
| # For example, here's several helpful packages to load | |
| import numpy as np # linear algebra | |
| import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) | |
| # Input data files are available in the read-only "../input/" directory | |
| # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory | |
| # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" | |
| # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session | |
| # In[1]: | |
| # import fastai | |
| # print(fastai.__version__) | |
| # In[2]: | |
| #/export | |
| from fastai.vision.all import * | |
| import gradio as gr | |
| import torch | |
# Load the exported fastai learner. NOTE(review): load_learner uses pickle,
# which can execute arbitrary code — safe only because this is our own export.
learn = load_learner('model-2.pkl')

# Probe one validation batch to discover the device and spatial size the
# learner expects, so we can run manual inference (avoids learn.predict
# passing (Image, dict) into transforms and triggering TypeError: 'PILImage' + 'dict').
# Fall back to CPU / 224px when the DataLoaders are unavailable.
try:
    _probe_batch = next(iter(learn.dls.valid))
    _INFERENCE_DEVICE = _probe_batch[0].device
    _INFERENCE_SIZE = _probe_batch[0].shape[-1]
except Exception:
    _INFERENCE_DEVICE = torch.device("cpu")
    _INFERENCE_SIZE = 224
| #Pydantic Warnings | |
| # UnsupportedFieldAttributeWarning: The 'repr' attribute with value False ... | |
| # UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True ... | |
| # These are coming from pydantic, a dependency used by some fastai internals or Gradio. | |
| # They don’t affect your model predictions. | |
| # It’s just saying certain metadata on fields is ignored in Python 3.12. | |
| # You can safely ignore these. | |
| # Pickle Security Warning | |
| # UserWarning: load_learner uses Python's insecure pickle module ... | |
| # This is fastai reminding you that load_learner uses pickle, which can execute arbitrary code. | |
| # Safe for you because this is your own export.pkl file. | |
| # The warning is just a general caution; it doesn’t break your notebook. | |
| # ✅ Only a concern if you were loading a pickle from an untrusted source. | |
| # In[3]: | |
| #/export | |
| # Function for multi-class flower predictions | |
| # This returns a Python dictionary with class names as keys and their corresponding probabilities as values. | |
| # def classify_flower(img): | |
| # pred, pred_idx, probs = learn.predict(img) | |
| # return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))} | |
| # def classify_flower(img): | |
| # pred, pred_idx, probs = learn.predict(img) | |
| # # Convert to dict and sort by probability descending | |
| # result = {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))} | |
| # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True)) | |
| # def classify_flower(img): | |
| # pred, pred_idx, probs = learn.predict(img) | |
| # return {learn.dls.vocab[i]: float(probs[i]) for i in range(len(probs))} | |
| # with gr.Blocks() as demo: | |
| # gr.Markdown("## 🌸 16 Flower Classifier") | |
| # with gr.Row(): | |
| # image_input = gr.Image(type="pil") | |
| # label_output = gr.Label() | |
| # submit_btn = gr.Button("Submit") | |
| # submit_btn.click(fn=classify_flower, inputs=image_input, outputs=label_output) | |
| # def classify_flower(img): | |
| # img = PILImage.create(img) | |
| # pred, pred_idx, probs = learn.predict(img) | |
def _preprocess_for_learner(pil_img):
    """Convert a PIL image into a normalized (1, C, H, W) float tensor placed
    on the learner's device, sized to match the model's expected input."""
    from torchvision import transforms as T
    # ImageNet statistics — assumes the learner was trained with standard
    # fastai/ImageNet normalization (TODO confirm against training pipeline).
    pipeline = T.Compose([
        T.Resize((_INFERENCE_SIZE, _INFERENCE_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    batch = pipeline(pil_img).unsqueeze(0)
    return batch.to(_INFERENCE_DEVICE)
def _coerce_to_pil(img):
    """Normalize a Gradio image payload into an RGB PIL image.

    Accepts the shapes Gradio may deliver: a dict with "path" or "image"
    keys (newer Gradio versions), a numpy array, or a PIL image.
    Returns None when the payload carries no usable image.
    """
    # Single PIL import for the whole coercion (the original imported it twice).
    from PIL import Image as PILImageModule

    if isinstance(img, dict):
        path = img.get("path")
        if isinstance(path, str):
            return PILImageModule.open(path).convert("RGB")
        img = img.get("image")
        if img is None:
            return None
    if isinstance(img, PILImageModule.Image):
        return img.convert("RGB")
    arr = np.asarray(img)
    if arr.dtype != np.uint8:
        # Image.fromarray rejects float arrays; clamp into displayable range.
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    if arr.ndim == 2:
        # Grayscale -> replicate into 3 channels so the model sees RGB.
        arr = np.stack([arr] * 3, axis=-1)
    return PILImageModule.fromarray(arr).convert("RGB")


def classify_flower(img):
    """Predict flower-class probabilities for a Gradio-supplied image.

    Parameters:
        img: Gradio image payload (dict / numpy array / PIL image) or None.

    Returns:
        A {class_name: probability} dict suitable for gr.Label, or None
        when no image was provided.
    """
    if img is None:
        return None
    pil_img = _coerce_to_pil(img)
    if pil_img is None:
        return None
    # Bypass learn.predict() to avoid (PILImage, dict) in the transform pipeline.
    x = _preprocess_for_learner(pil_img)
    learn.model.eval()
    with torch.no_grad():
        logits = learn.model(x)
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    return {
        learn.dls.vocab[i]: float(probs[i])
        for i in range(len(probs))
    }
# Assemble the Gradio UI: image uploader and ranked class probabilities side
# by side, with a button wired to the classifier. `demo` is launched below.
with gr.Blocks() as demo:
    gr.Markdown("## 🌸 16 Flower Classifier")
    with gr.Row():
        uploaded_image = gr.Image(type="pil")
        prediction_label = gr.Label(num_top_classes=16)
    classify_button = gr.Button("Submit")
    classify_button.click(
        fn=classify_flower,
        inputs=uploaded_image,
        outputs=prediction_label,
    )
| # def classify_flower(img): | |
| # pred, pred_idx, probs = learn.predict(img) | |
| # # gr.Label expects a dict {label: confidence} | |
| # result = { | |
| # learn.dls.vocab[i]: float(probs[i]) | |
| # for i in range(len(probs)) | |
| # } | |
| # return result | |
| # with gr.Blocks() as demo: | |
| # gr.Markdown("## 🌸 16 Flower Classifier") | |
| # with gr.Row(): | |
| # image_input = gr.Image(type="pil") | |
| # label_output = gr.Label() | |
| # submit_btn = gr.Button("Submit") | |
| # submit_btn.click(fn=classify_flower, inputs=image_input, outputs=label_output) | |
# Start the Gradio server for the app defined above (blocks until stopped).
demo.launch()
| # ## EXPORT | |
| # In[19]: | |