Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import timm | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| # Optional: Turn off file watchers in HF Spaces to avoid torch-related warnings | |
| os.environ["STREAMLIT_WATCHER_TYPE"] = "none" | |
# Segmentation network: MobileViT encoder (timm) + small convolutional decoder.
class MobileViTSegmentation(nn.Module):
    """Binary segmentation model built on a MobileViT backbone.

    The timm backbone (features_only=True) produces a feature pyramid; only
    the deepest feature map is fed to a lightweight decoder of three
    conv + 2x-bilinear-upsample stages ending in a 1x1 conv and sigmoid,
    yielding a single-channel per-pixel probability map.
    """

    def __init__(self, encoder_name='mobilevit_s', pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(
            encoder_name, features_only=True, pretrained=pretrained
        )
        self.encoder_channels = self.backbone.feature_info.channels()

        # Channel narrowing 128 -> 64 -> 32, each stage doubling resolution,
        # then a 1x1 projection to one channel squashed through a sigmoid.
        decoder_stages = [
            nn.Conv2d(self.encoder_channels[-1], 128, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_stages)

    def forward(self, x):
        """Return an (N, 1, H, W) probability mask at the input's resolution."""
        deepest = self.backbone(x)[-1]
        mask = self.decoder(deepest)
        # The decoder's fixed 8x upsampling need not match the input size
        # exactly, so interpolate back to (H, W) of the input.
        return nn.functional.interpolate(
            mask, size=(x.shape[2], x.shape[3]),
            mode='bilinear', align_corners=False,
        )
# Cache the model across Streamlit reruns: without this, every widget
# interaction (each script rerun) rebuilds the network and re-reads the
# checkpoint from disk.
@st.cache_resource
def load_model():
    """Build the segmentation model and load the trained weights.

    Returns:
        The model on CPU in eval mode.

    On any failure the error is surfaced in the UI via st.error and the
    script run is halted with st.stop().
    """
    try:
        with st.spinner("Loading model..."):
            model = MobileViTSegmentation()
            # weights_only=True: the checkpoint is a plain state_dict, so
            # refuse to unpickle arbitrary objects (safer torch.load).
            state = torch.load(
                "mobilevit_teeth_segmentation.pth",
                map_location="cpu",
                weights_only=True,
            )
            model.load_state_dict(state)
        model.eval()
        return model
    except Exception as e:
        st.error(f"❌ Failed to load model: {e}")
        st.stop()
def predict_mask(image, model, threshold=0.7):
    """Run the model on a PIL image and return a 256x256 binary mask.

    The image is resized to 256x256 and tensorized; the model's sigmoid
    output is thresholded to a uint8 array of 0/1 values. Returns None
    (after reporting via st.error) if inference fails.
    """
    try:
        to_tensor = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])
        batch = to_tensor(image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            probs = model(batch)
        binary = (probs.squeeze().numpy() > threshold).astype(np.uint8)
        return binary
    except Exception as e:
        st.error(f"❌ Prediction failed: {e}")
        return None
def overlay_mask(image, mask, color=(0, 0, 255), alpha=0.4):
    """Tint the masked region of *image* with *color*; leave the rest intact.

    Args:
        image: PIL image (converted to RGB internally).
        mask: 2D array of 0/1 values, any size (resized to the image size).
        color: overlay color as a 3-tuple applied to the RGB array.
        alpha: overlay opacity inside the mask, in [0, 1].

    Returns:
        HxWx3 uint8 numpy array. Falls back to the unmodified image
        (after st.error) if the overlay fails.
    """
    try:
        image_np = np.array(image.convert("RGB"))
        mask_resized = cv2.resize(mask, (image_np.shape[1], image_np.shape[0]))
        color_mask = np.zeros_like(image_np)
        color_mask[:, :] = color
        # Blend the whole frame with the color layer, then keep the blend
        # only where the mask is set. (The previous version blended the
        # image with a mostly-black overlay, which dimmed every unmasked
        # pixel by (1 - alpha) instead of leaving it untouched.)
        blended_full = cv2.addWeighted(image_np, 1 - alpha, color_mask, alpha, 0)
        result = np.where(mask_resized[..., None] == 1, blended_full, image_np)
        return result
    except Exception as e:
        st.error(f"❌ Mask overlay failed: {e}")
        return np.array(image)
# ---------------- Streamlit UI ----------------
st.set_page_config(page_title="Tooth Segmentation", layout="wide")
st.title("🦷 Tooth Segmentation from Mouth Images")
st.markdown("Upload a **face or mouth image**, and this app will overlay the **predicted tooth segmentation mask**.")

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    try:
        source_image = Image.open(uploaded_file).convert("RGB")
        model = load_model()
        mask = predict_mask(source_image, model)
        if mask is not None:
            highlighted = overlay_mask(source_image, mask, color=(0, 0, 255), alpha=0.4)
            # Original and overlay side by side.
            left, right = st.columns(2)
            left.image(source_image, caption="Original Image", use_container_width=True)
            right.image(highlighted, caption="Tooth Mask Overlay", use_container_width=True)
    except Exception as e:
        st.error(f"❌ Error processing image: {e}")