Pengi5659 commited on
Commit
1d7ddcf
·
verified ·
1 Parent(s): d76b28c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -20
app.py CHANGED
@@ -1,23 +1,25 @@
1
- import os
2
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 
 
3
 
4
- import tensorflow as tf
5
- import cv2
6
- import imghdr
7
- import os
8
 
9
- # Clean image files
10
- data_dir = 'data'
11
- image_exts = ['jpeg', 'jpg', 'bmp', 'png']
 
12
 
13
- for image_class in os.listdir(data_dir):
14
- for image in os.listdir(os.path.join(data_dir, image_class)):
15
- image_path = os.path.join(data_dir, image_class, image)
16
- try:
17
- img = cv2.imread(image_path)
18
- tip = imghdr.what(image_path)
19
- if tip not in image_exts:
20
- print('Removing:', image_path)
21
- os.remove(image_path)
22
- except Exception as e:
23
- print('Issue with image:', image_path)
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ import torchvision.models as models
5
+ from PIL import Image
6
 
7
+ # Load model (ensure it's uploaded to the Space)
8
+ model = models.resnet18(pretrained=True)
9
+ model.fc = torch.nn.Linear(model.fc.in_features, 2) # Adjust for your classes
 
10
 
11
+ transform = transforms.Compose([
12
+ transforms.Resize((224, 224)),
13
+ transforms.ToTensor()
14
+ ])
15
 
16
+ # Define the function to classify images
17
+ def classify_image(image):
18
+ image = transform(image).unsqueeze(0)
19
+ output = model(image)
20
+ _, predicted = torch.max(output, 1)
21
+ return "Class_A" if predicted.item() == 0 else "Class_B"
22
+
23
+ # Set up Gradio interface
24
+ iface = gr.Interface(fn=classify_image, inputs="webcam", outputs="text")
25
+ iface.launch()