# app.py — Streamlit app: brain tumor MRI classification with a fine-tuned ResNet18
# (4 classes: glioma, meningioma, no tumor, pituitary).
import streamlit as st
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
from PIL import Image
# Inject custom CSS: full-page background image, title banner, centered green
# button, and the pink result box rendered after prediction.
# Fixes vs. original: the background url() contained two URLs fused into one
# invalid string (no background ever rendered), and .output-container used
# width:200%, which overflowed the page horizontally.
st.markdown(
    """
    <style>
    /* Set background image for the entire app */
    .stApp {
        background: url('https://wp.technologyreview.com/wp-content/uploads/2023/04/brain-decode2.jpeg') no-repeat center center fixed;
        background-size: cover;
    }
    .stApp h1 {
        background-color: rgba(0, 0, 128, 0.7);
        color: #ffffff;
        padding: 10px;
        border-radius: 5px;
        font-size: 2.2em;
        text-align: center;
        white-space: nowrap; /* Prevents line break */
        overflow: hidden;
        text-overflow: ellipsis;
        max-width: 100%;
        margin: 0 auto;
    }
    /* Style for the button */
    .stButton>button {
        background-color: #4CAF50; /* Green */
        color: white;
        font-size: 1.2em;
        border-radius: 10px;
        padding: 10px 24px;
        border: none;
    }
    /* Center the button */
    .stButton {
        display: flex;
        justify-content: center;
    }
    /* Style for the output container */
    .output-container {
        background-color: lightpink;
        color: black;
        font-size: 1.5em;
        padding: 15px;
        border-radius: 10px;
        margin-top: 20px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
        width: 100%; /* was 200%: overflowed the viewport and broke centering */
        margin-left: auto;
        margin-right: auto;
        text-align: center;
    }
    </style>
    """,
    unsafe_allow_html=True
)
# Title
st.title("Brain Tumor Classification")

# Class labels in the index order the classifier head was trained on.
class_names = ['glioma_tumor', 'meningioma_tumor', 'no_tumor', 'pituitary_tumor']


@st.cache_resource
def load_model():
    """Build a ResNet18, replace its head for 4 tumor classes, and load trained weights.

    Cached with st.cache_resource so the model is constructed and the checkpoint
    read from disk only once, not on every Streamlit rerun (every widget interaction).

    Returns:
        The model in eval mode, on CPU.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Swap the ImageNet 1000-way head for our 4-class head.
    net.fc = nn.Linear(net.fc.in_features, len(class_names))
    # map_location='cpu' lets the checkpoint load on machines without a GPU.
    net.load_state_dict(
        torch.load('resnet18_model (1).pth', map_location=torch.device('cpu'))
    )
    net.eval()  # disable dropout/batch-norm updates for inference
    return net


model = load_model()
# Image upload
uploaded_img = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_img is not None:
    # Force 3 channels: uploaded PNGs may be RGBA or grayscale, which would
    # crash the 3-channel Normalize and the conv stem below.
    image = Image.open(uploaded_img).convert('RGB')
    # Display uploaded image in a smaller size
    st.image(image, caption="Uploaded Image", width=200)

    # Preprocessing — mean/std look like per-dataset grayscale stats replicated
    # across channels; should match what the model was trained with (TODO confirm).
    sample_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.1776, 0.1776, 0.1776], std=[0.1735, 0.1735, 0.1735])
    ])
    transformed_img = sample_transform(image).unsqueeze(0)  # add batch dimension

    # Model inference: no gradients needed, just the argmax class index.
    with torch.no_grad():
        pred = model(transformed_img).argmax(dim=1).item()

    # Stylish output box (styled by the .output-container CSS injected above)
    st.markdown(
        f"""
        <div class="output-container">
        🧠 <strong>Predicted Class:</strong> {class_names[pred]}
        </div>
        """,
        unsafe_allow_html=True
    )