Spaces:
Build error
Build error
| import torch | |
| from transformers import AutoModel | |
| import torch.nn as nn | |
| from PIL import Image | |
| import numpy as np | |
| import streamlit as st | |
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face Hub repository holding the trained backbone weights.
MODEL_ID = 'dhhd255/EfficientNet_ParkinsonsPred'


@st.cache_resource
def load_model(model_id=MODEL_ID):
    """Load the backbone from the Hugging Face Hub and move it to ``device``.

    Streamlit re-executes this whole script on every widget interaction
    (each file upload, for example). ``st.cache_resource`` keeps a single
    loaded model per ``model_id`` for the lifetime of the server process.
    The weights are therefore deserialised once instead of on every rerun.

    Args:
        model_id: Hub repository id of the checkpoint to load.

    Returns:
        The model, placed on ``device`` and switched to evaluation mode.
    """
    loaded = AutoModel.from_pretrained(model_id)
    loaded = loaded.to(device)
    # Inference only: make dropout / batch-norm layers use eval behaviour.
    loaded.eval()
    return loaded


# Shared, cached model instance used by the prediction code below.
model = load_model()
# Page-wide custom CSS. It sets the Inter font on the body and defines a large
# bold `.result` label, which `.healthy` (green) or `.parkinsons` (red) colours.
# It also makes `.social-links` a centred flex row whose anchors are not
# underlined and are spaced 10px apart horizontally.
_CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter&display=swap');
body {
font-family: 'Inter', sans-serif;
}
.result {
font-size: 24px;
font-weight: bold;
}
.healthy {
color: #007E3F;
}
.parkinsons {
color: #C30000;
}
.social-links {
display: flex;
text-decoration:none;
justify-content: center;
}
.social-links a {
text-decoration:none;
padding: 0 10px;
}
</style>
"""

# Raw HTML (<style>) must be explicitly allowed for Streamlit to render it.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Size (width, height) in pixels that an uploaded drawing is resized to.
IMAGE_SIZE = (224, 224)
# Number of output classes: index 0 -> Parkinson's, any other -> Healthy.
NUM_CLASSES = 2


@st.cache_resource
def _get_classifier_head(in_features, num_classes=NUM_CLASSES):
    """Return the linear layer that maps flattened backbone features to logits.

    NOTE(review): this layer is freshly initialised with RANDOM weights. It is
    not loaded from the checkpoint, so its output carries no learned meaning.
    Load the trained classification head (or a model class that includes
    one) before trusting these predictions.

    Caching the layer means one instance is reused for the life of the server
    process. Without the cache, a new random layer was built on every
    Streamlit rerun, so an unrelated interaction (such as uploading the Wave
    drawing) could flip the Spiral result.
    """
    head = nn.Linear(in_features, num_classes)
    head = head.to(device)
    head.eval()
    return head


def _to_model_input(image):
    """Turn an RGB PIL image into a float tensor of shape (1, 3, H, W) on ``device``."""
    pixels = np.array(image)  # (H, W, 3) because the image was converted to RGB
    # permute(2, 0, 1) yields channels-first (3, H, W). The previous
    # transpose(0, 2) produced (3, W, H), i.e. the drawing mirrored across its
    # diagonal.
    # TODO(review): confirm the training pipeline used standard (C, H, W)
    # input, and whether it expects pixel values scaled from 0..255 and/or
    # mean/std normalised (none is applied here).
    tensor = torch.from_numpy(pixels).permute(2, 0, 1).float().unsqueeze(0)
    return tensor.to(device)


def _classify(features):
    """Map backbone features to ``(predicted class index, softmax confidence)``."""
    # Flatten every non-batch dimension into a single feature vector per sample.
    flat = features.flatten(start_dim=1)
    logits = _get_classifier_head(flat.shape[1])(flat)
    probabilities = torch.softmax(logits, dim=1)
    predicted_class = int(torch.argmax(probabilities, dim=1).item())
    confidence = probabilities[0, predicted_class].item()
    return predicted_class, confidence


st.title("Parkinson's Disease Prediction")
uploaded_file = st.file_uploader("Upload your :blue[Spiral] drawing here", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    col1, col2 = st.columns(2)

    # Decode the upload, force 3-channel RGB and resize. The context manager
    # closes the underlying file once the converted copy has been made.
    with Image.open(uploaded_file) as raw_image:
        new_image = raw_image.convert('RGB').resize(IMAGE_SIZE)
    col1.image(new_image, width=255)

    # Inference only: no autograd graph is needed for the backbone or the head.
    with torch.no_grad():
        predictions = model(_to_model_input(new_image))
        predicted_class, confidence = _classify(predictions.last_hidden_state)

    if predicted_class == 0:
        col2.markdown('<span class="result parkinsons">Predicted class: Parkinson\'s</span>', unsafe_allow_html=True)
    else:
        col2.markdown('<span class="result healthy">Predicted class: Healthy</span>', unsafe_allow_html=True)
    col2.caption(f'{confidence*100:.0f}% sure')

# TODO: the Wave drawing is collected but never analysed. It is stored in its
# own variable so it no longer overwrites the Spiral upload.
wave_file = st.file_uploader("Upload your :blue[Wave] drawing here", type=["png", "jpg", "jpeg"])
st.divider()

# Footer links as (url, label) pairs. They are emitted as anchors inside a
# `social-links` div, which the page CSS styles.
# TODO: these URLs are template placeholders; replace them with real handles.
_SOCIAL_LINKS = (
    ("https://twitter.com/your_twitter_handle", "Twitter"),
    ("https://facebook.com/your_facebook_page", "Facebook"),
    ("https://instagram.com/your_instagram_handle", "Instagram"),
)
_anchor_html = "\n".join(f'<a href="{url}">{label}</a>' for url, label in _SOCIAL_LINKS)
st.markdown(f'\n<div class="social-links">\n{_anchor_html}\n</div>\n', unsafe_allow_html=True)