JaredBailey commited on
Commit
45bf25d
·
verified ·
1 Parent(s): bf36852

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -1
app.py CHANGED
@@ -4,4 +4,50 @@ from torchvision import transforms
4
  from PIL import Image
5
 
6
 
7
- st.write("Hi Jay and Suneel")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from PIL import Image
5
 
6
 
7
+ st.write("Hi Jay and Suneel")
8
+
9
+ # Load the image and apply transformations
10
+
11
+ def predict_image(image_path, model):
12
+
13
+ image = Image.open(image_path).convert('RGB')
14
+ transform = transforms.Compose([
15
+ transforms.Resize((224, 224)),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
18
+ std=[0.229, 0.224, 0.225])
19
+ ])
20
+ input_image = transform(image).unsqueeze(0) # Add batch dimension
21
+
22
+ # Move input tensor to the device (GPU if available)
23
+ input_image = input_image.to(device)
24
+
25
+ # Perform inference
26
+ model.eval()
27
+ with torch.no_grad():
28
+ output = model(input_image)
29
+
30
+ # Get predicted class probabilities and class index
31
+ probabilities = torch.softmax(output, dim=1)[0]
32
+ predicted_class_index = torch.argmax(probabilities).item()
33
+
34
+ # Map class index to class label
35
+ class_labels = dataset.classes
36
+ predicted_class_label = class_labels[predicted_class_index]
37
+
38
+ return predicted_class_label
39
+ # print("Class probabilities:")
40
+ # for i, prob in enumerate(probabilities):
41
+ # print(f"{class_labels[i]}: {prob:.4f}")
42
+
43
+
44
+ model_loaded = torchvision.models.resnet18(pretrained=False) # Initialize ResNet18 without pretraining
45
+ model_loaded.fc = torch.nn.Linear(model_loaded.fc.in_features, len(dataset.classes)) # Modify the fully connected layer
46
+ model_loaded = model_loaded.to(device) # Move the model to the appropriate device (GPU or CPU)
47
+
48
+ # Load the saved state dictionary into the model
49
+ model_path = 'resnet18_custom_model.pth'
50
+ model_loaded.load_state_dict(torch.load(model_path))
51
+
52
+ # Set the model to evaluation mode
53
+ model_loaded.eval()