Pengi5659 commited on
Commit
46d545a
·
verified ·
1 Parent(s): 56c3b76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -17
app.py CHANGED
@@ -3,28 +3,19 @@ import torch
3
  import torchvision.transforms as transforms
4
  import torchvision.models as models
5
  from PIL import Image
6
- import os
7
- from torch.utils.data import DataLoader, Dataset
8
- from torchvision.datasets import ImageFolder
9
 
10
- # Define transformations for images
 
 
 
 
 
11
  transform = transforms.Compose([
12
  transforms.Resize((224, 224)),
13
  transforms.ToTensor()
14
  ])
15
 
16
- # Load dataset from your image folders
17
- dataset = ImageFolder(root="posture_samples", transform=transform) # Ensure "posture_samples" contains two subfolders: "Good Posture-samples" and "Bad Posture-samples"
18
-
19
- # Create a DataLoader
20
- dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
21
-
22
- # Load pre-trained ResNet18 model
23
- model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
24
- model.fc = torch.nn.Linear(model.fc.in_features, 2) # Adjust output for two classes
25
- model.eval() # Set to evaluation mode
26
-
27
- # Define function to classify an image
28
  def classify_image(image):
29
  image = transform(image).unsqueeze(0)
30
  output = model(image)
@@ -32,6 +23,7 @@ def classify_image(image):
32
  return "Good Posture" if predicted.item() == 0 else "Bad Posture"
33
 
34
  # Set up Gradio interface
35
- iface = gr.Interface(fn=classify_image, inputs=gr.Image(source="webcam", type="pil"), outputs="text")
36
  iface.launch()
37
 
 
 
3
  import torchvision.transforms as transforms
4
  import torchvision.models as models
5
  from PIL import Image
 
 
 
6
 
7
+ # Load the ResNet18 model with pre-trained weights
8
+ model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
9
+ model.fc = torch.nn.Linear(model.fc.in_features, 2) # Adjust for two classes
10
+ model.eval() # Set to evaluation mode
11
+
12
+ # Define image transformation
13
  transform = transforms.Compose([
14
  transforms.Resize((224, 224)),
15
  transforms.ToTensor()
16
  ])
17
 
18
+ # Function to classify posture images
 
 
 
 
 
 
 
 
 
 
 
19
  def classify_image(image):
20
  image = transform(image).unsqueeze(0)
21
  output = model(image)
 
23
  return "Good Posture" if predicted.item() == 0 else "Bad Posture"
24
 
25
  # Set up Gradio interface
26
+ iface = gr.Interface(fn=classify_image, inputs=gr.Image(type="pil", tool="camera"), outputs="text")
27
  iface.launch()
28
 
29
+