binaychandra commited on
Commit
4533cc6
·
1 Parent(s): 30bfd08

mobilenet update

Browse files
Files changed (2) hide show
  1. app.py +4 -2
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,9 +2,10 @@ import gradio as gr
2
  import torch
3
  from torchvision import models, transforms
4
  from PIL import Image
 
5
 
6
  # Load the pre-trained ResNet model
7
- model = models.resnet50(pretrained=True)
8
  model.eval()
9
 
10
  # Define the transformation for input images
@@ -17,7 +18,8 @@ preprocess = transforms.Compose([
17
 
18
  # Define the labels for ImageNet classes (you may need to adjust this based on your model)
19
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
20
- labels = gr.utils.get_file(LABELS_URL, cache=True).read_text().splitlines()
 
21
 
22
  # Function to perform image classification
23
  def classify_image(input_image):
 
2
  import torch
3
  from torchvision import models, transforms
4
  from PIL import Image
5
+ import requests
6
 
7
  # Load the pre-trained ResNet model
8
+ model = models.MobileNet_V3_Small_Weights(pretrained=True)
9
  model.eval()
10
 
11
  # Define the transformation for input images
 
18
 
19
  # Define the labels for ImageNet classes (you may need to adjust this based on your model)
20
  LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
21
+ response = requests.get(LABELS_URL)
22
+ labels = response.text.splitlines()
23
 
24
  # Function to perform image classification
25
  def classify_image(input_image):
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  gradio
2
  pillow
3
  torch
 
4
  torchvision
 
1
  gradio
2
  pillow
3
  torch
4
+ requests
5
  torchvision