resolverkatla commited on
Commit
2384571
·
verified ·
1 Parent(s): fc7ac8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -38
app.py CHANGED
@@ -1,38 +1,41 @@
1
- import streamlit as st
2
- from PIL import Image
3
- import torch
4
- from torchvision import transforms
5
- from models.cnn import CNNModel
6
- from utils.transforms import get_transforms
7
-
8
- @st.cache_resource
9
- def load_model(model_path='saved_models/cnn_model.pth'):
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- checkpoint = torch.load(model_path, map_location=device)
12
- class_names = checkpoint['class_names']
13
- model = CNNModel(num_classes=len(class_names))
14
- model.load_state_dict(checkpoint['model_state_dict'])
15
- model.to(device)
16
- model.eval()
17
- return model, class_names, device
18
-
19
- st.title("📸 Intel Image Classification")
20
- st.write("Upload an image to classify it into one of the image categories: buildings, forest, glacier, mountain, sea, or street.")
21
-
22
- model, class_names, device = load_model()
23
-
24
- uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
25
-
26
- if uploaded_file:
27
- image = Image.open(uploaded_file).convert("RGB")
28
- st.image(image, caption="Uploaded Image", use_container_width=True)
29
-
30
- transform = get_transforms(train=False)
31
- image_tensor = transform(image).unsqueeze(0).to(device)
32
-
33
- with torch.no_grad():
34
- output = model(image_tensor)
35
- predicted_idx = torch.argmax(output, 1).item()
36
- predicted_class = class_names[predicted_idx]
37
-
38
- st.success(f"Predicted class: {predicted_class}")
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision import transforms
6
+ from models.cnn import CNNModel
7
+ from utils.transforms import get_transforms
8
+
9
+ os.environ["STREAMLIT_ROOT"] = "/tmp/.streamlit"
10
+
11
+ @st.cache_resource
12
+ def load_model(model_path='saved_models/cnn_model.pth'):
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ checkpoint = torch.load(model_path, map_location=device)
15
+ class_names = checkpoint['class_names']
16
+ model = CNNModel(num_classes=len(class_names))
17
+ model.load_state_dict(checkpoint['model_state_dict'])
18
+ model.to(device)
19
+ model.eval()
20
+ return model, class_names, device
21
+
22
+ st.title("📸 Intel Image Classification")
23
+ st.write("Upload an image to classify it into one of the image categories: buildings, forest, glacier, mountain, sea, or street.")
24
+
25
+ model, class_names, device = load_model()
26
+
27
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
28
+
29
+ if uploaded_file:
30
+ image = Image.open(uploaded_file).convert("RGB")
31
+ st.image(image, caption="Uploaded Image", use_container_width=True)
32
+
33
+ transform = get_transforms(train=False)
34
+ image_tensor = transform(image).unsqueeze(0).to(device)
35
+
36
+ with torch.no_grad():
37
+ output = model(image_tensor)
38
+ predicted_idx = torch.argmax(output, 1).item()
39
+ predicted_class = class_names[predicted_idx]
40
+
41
+ st.success(f"Predicted class: {predicted_class}")