binaychandra commited on
Commit
30bfd08
·
1 Parent(s): 140e961

resnet app

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +50 -3
  3. requirements.txt +4 -1
.gitignore CHANGED
@@ -0,0 +1 @@
 
 
1
+ temp.txt
app.py CHANGED
@@ -1,7 +1,54 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello there, " + name + "!!"
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
 
6
+ # Load the pre-trained ResNet model
7
+ model = models.resnet50(pretrained=True)
8
+ model.eval()
9
 
10
+ # Define the transformation for input images
11
+ preprocess = transforms.Compose([
12
+ transforms.Resize(256),
13
+ transforms.CenterCrop(224),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
16
+ ])
17
+
18
+ # Define the labels for ImageNet classes (you may need to adjust this based on your model)
19
+ LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
20
+ labels = gr.utils.get_file(LABELS_URL, cache=True).read_text().splitlines()
21
+
22
+ # Function to perform image classification
23
+ def classify_image(input_image):
24
+ # Preprocess the image
25
+ input_tensor = preprocess(input_image)
26
+ input_batch = input_tensor.unsqueeze(0)
27
+
28
+ # Make predictions
29
+ with torch.no_grad():
30
+ output = model(input_batch)
31
+
32
+ # Get the predicted class index
33
+ _, predicted_idx = torch.max(output, 1)
34
+
35
+ # Get the predicted label
36
+ predicted_label = labels[predicted_idx.item()]
37
+
38
+ return predicted_label
39
+
40
+ # Gradio UI components
41
+ image_input = gr.Image(preprocessing_fn=lambda img: Image.open(img.name))
42
+ output_label = gr.Textbox()
43
+
44
+ # Gradio interface
45
+ iface = gr.Interface(
46
+ fn=classify_image,
47
+ inputs=image_input,
48
+ outputs=output_label,
49
+ live=True,
50
+ capture_session=True
51
+ )
52
+
53
+ # Launch the Gradio app
54
  iface.launch()
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- gradio
 
 
 
 
1
+ gradio
2
+ pillow
3
+ torch
4
+ torchvision