| import streamlit as st |
| import torch.nn as nn |
|
|
| import torch |
| from torchvision import models, transforms |
| from PIL import Image |
|
|
# Class labels, in the order of the classifier's output units: the index
# returned by the model is looked up in this list.
CATEGORIES = [
    "AIHOLE",
    "BILLESHWAR_TEMPLE",
    "CHENNAKESHWARA_TEMPLE",
    "HAMPI_CHARIOT",
    "IBRAHIM_ROZA",
    "JAIN_BASADI",
    "KAMAL_BASTI",
    "KEDARESHWARA_TEMPLE",
    "KESHAVA_TEMPLE",
    "LOTUS_MAHAL",
]
# Side length (pixels) of the square input fed to the network.
IMG_SIZE = 224


@st.cache_resource
def _load_model(weights_path="trained_model.pt"):
    """Build the ResNet-50 classifier and load its fine-tuned weights.

    Streamlit re-executes this whole script on every widget interaction;
    caching the constructed model means the checkpoint is read from disk
    once per process instead of on every rerun.

    Args:
        weights_path: Path to a saved ``state_dict`` matching this
            architecture (ResNet-50 with a ``len(CATEGORIES)``-way head).

    Returns:
        The model on CPU, in eval mode.
    """
    # weights=None: random init; the real weights come from the checkpoint.
    # (Replaces the deprecated ``pretrained=False`` argument.)
    net = models.resnet50(weights=None)
    # Swap the final fully-connected layer for one sized to our label set.
    net.fc = nn.Linear(net.fc.in_features, len(CATEGORIES))
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()  # inference mode for BatchNorm/Dropout layers
    return net


model = _load_model()
|
|
| |
# Preprocessing applied to every input image, in order:
#   1. resize to an IMG_SIZE x IMG_SIZE square (aspect ratio not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the mean/std below.
# The mean/std values are the standard ImageNet statistics; presumably they
# match what was used when the checkpoint was trained — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
| |
def classify_image(image):
    """Predict the category index for a single PIL image.

    Args:
        image: A PIL image in any mode. It is converted to RGB before
            preprocessing, so RGBA PNGs, grayscale and palette images are
            accepted.

    Returns:
        int: Index into ``CATEGORIES`` of the highest-scoring class.
    """
    # ToTensor() keeps the source channel count (4 for RGBA, 1 for "L"),
    # which breaks the 3-channel Normalize and the network's first conv.
    rgb_image = image.convert("RGB")
    # unsqueeze(0) adds the batch dimension the model expects.
    batch = transform(rgb_image).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)

    return logits.argmax(dim=1).item()
|
|
| |
def main():
    """Render the app: upload an image, preview it, and classify on demand."""
    st.title("Temple Image Classification")

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return  # nothing uploaded yet (or the upload was cleared)

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; decode now so truncated data fails here rather
        # than later inside st.image or classify_image.
        image.load()
    except OSError:
        # Covers PIL's UnidentifiedImageError (a subclass of OSError) for
        # non-image data, plus decode errors from corrupt files that still
        # carry an allowed extension.
        st.error("The uploaded file could not be read as an image.")
        return

    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Classify"):
        prediction = classify_image(image)
        st.write(f"Predicted Category: {CATEGORIES[prediction]}")


if __name__ == "__main__":
    main()
|
|