# my-model / app.py
# faranbutt789's picture
# Update app.py
# 12b9e29 verified
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
from unet import ImprovedUNet
from huggingface_hub import hf_hub_download
import cv2
# --- Checkpoint retrieval ---------------------------------------------------
# Try the Hugging Face Hub first. Any failure (network, auth, missing file) is
# reported on stdout and a local checkpoint path is used instead.
try:
    weights_path = hf_hub_download(
        repo_id="faranbutt789/my-model",
        filename="unet_weights.pth",
    )
except Exception as download_error:
    print(f"Error downloading weights: {download_error}")
    weights_path = "unet_weights_v2.pth"

# --- Model construction -----------------------------------------------------
model = ImprovedUNet()
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should ever be loaded here.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    print("Model weights loaded successfully!")
except Exception as load_error:
    # Best effort: the app still starts, but with untrained weights.
    print(f"Error loading model weights: {load_error}")
    print("Using randomly initialized model (for testing)")

# Inference only: move to the chosen device and switch to eval mode.
model.to(device)
model.eval()
# Preprocessing pipeline (intended to mirror what was used at training time).
# Every input image is resized to IMG_HEIGHT x IMG_WIDTH and converted to a
# tensor before it is handed to the model.
IMG_HEIGHT = 128
IMG_WIDTH = 128
transform = transforms.Compose(
    [
        transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
        # NOTE: no normalization step is applied. If training used one, add it
        # here, e.g.:
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)
# Weight of the red crack layer when it is blended onto the input image.
_OVERLAY_ALPHA = 0.4


def _blank_outputs(image):
    """Return an all-black ``(mask, overlay)`` pair used when inference fails.

    The images take the size of *image* when it exposes a ``(width, height)``
    ``size`` tuple. Otherwise the model input size is used, so the error path
    can never fail on an unbound or malformed size.
    """
    size = getattr(image, "size", None)
    if not (isinstance(size, tuple) and len(size) == 2):
        size = (IMG_WIDTH, IMG_HEIGHT)
    blank = Image.new('L', size, 0)
    return blank, blank


def _segment(image, threshold=None):
    """Run the model on *image* and build the ``(mask, overlay)`` PIL images.

    Args:
        image: PIL image to segment (converted to RGB if necessary).
        threshold: ``None`` yields a grayscale mask scaled to 0-255; a number
            yields a binary mask that is 255 where the score exceeds it.

    Returns:
        ``(mask_img, overlay_img)``, both at the original image size. The
        overlay is the input blended with the mask drawn in the red channel.
    """
    orig_w, orig_h = image.size
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Add the batch dimension expected by the network.
    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img_tensor)

    # Drop the two leading dims.
    # NOTE(review): assumes the model returns a (1, 1, H, W) score map.
    scores = pred.squeeze(0).squeeze(0).cpu().numpy()

    if threshold is None:
        # Clip before the uint8 cast so scores outside [0, 1] saturate instead
        # of wrapping around.
        # NOTE(review): assumes scores are probabilities, not raw logits.
        mask = (np.clip(scores, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        mask = (scores > threshold).astype(np.uint8) * 255

    # Nearest-neighbour keeps a thresholded mask strictly binary when upscaled.
    mask_resized = cv2.resize(
        mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST
    )

    # Red-only layer: cracks show up as a red tint on the original image.
    colored_mask = np.zeros((orig_h, orig_w, 3), dtype=np.uint8)
    colored_mask[:, :, 0] = mask_resized
    overlay = cv2.addWeighted(
        np.array(image), 1 - _OVERLAY_ALPHA, colored_mask, _OVERLAY_ALPHA, 0
    )

    # A 2-D uint8 array is interpreted as mode 'L', so no explicit mode is needed.
    return Image.fromarray(mask_resized), Image.fromarray(overlay)


def predict(image):
    """Segment cracks in *image* and return ``(mask, overlay)`` PIL images.

    Returns ``(None, None)`` when no image was supplied, so the handler always
    provides one value per Gradio output component. Any failure during
    inference is printed and reported as a pair of blank images rather than
    raised.
    """
    if image is None:
        return None, None
    try:
        return _segment(image)
    except Exception as e:
        print(f"Error in prediction: {e}")
        # Return blank images in case of error.
        return _blank_outputs(image)


def predict_with_threshold(image, threshold):
    """Segment cracks in *image* using a binary cut-off at *threshold*.

    Same contract as :func:`predict`, except that the mask is 255 where the
    model score is strictly greater than *threshold* and 0 elsewhere.
    """
    if image is None:
        return None, None
    try:
        return _segment(image, threshold)
    except Exception as e:
        print(f"Error in prediction with threshold: {e}")
        return _blank_outputs(image)
# Gradio UI: a header, two tabs (plain prediction and prediction with an
# adjustable threshold), and usage notes. Each tab wires its button to the
# matching handler and shows the mask plus the overlay.
with gr.Blocks(title="UNet Crack Segmentation", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔍 Concrete Crack Segmentation with UNet

        Upload an image of a concrete surface to detect and segment cracks using a trained UNet model.

        **Features:**
        - Advanced UNet architecture with batch normalization and dropout
        - Optimized for highly imbalanced crack detection
        - Interactive threshold adjustment
        - Colored overlay visualization
        """
    )
    with gr.Tabs():
        # Tab 1: grayscale mask straight from the model output.
        with gr.TabItem("Basic Prediction"):
            with gr.Row():
                with gr.Column():
                    input_image1 = gr.Image(
                        type="pil",
                        label="Upload Concrete Image",
                        height=400
                    )
                    predict_btn1 = gr.Button("🔍 Detect Cracks", variant="primary", size="lg")
                with gr.Column():
                    output_mask1 = gr.Image(
                        label="Crack Mask",
                        height=400
                    )
                    output_overlay1 = gr.Image(
                        label="Overlay Visualization",
                        height=400
                    )
            predict_btn1.click(
                predict,
                inputs=[input_image1],
                outputs=[output_mask1, output_overlay1]
            )
        # Tab 2: binary mask with a user-selected cut-off.
        with gr.TabItem("Advanced Prediction"):
            with gr.Row():
                with gr.Column():
                    input_image2 = gr.Image(
                        type="pil",
                        label="Upload Concrete Image",
                        height=400
                    )
                    threshold_slider = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        value=0.5,
                        step=0.1,
                        label="Detection Threshold"
                    )
                    predict_btn2 = gr.Button("🔍 Detect Cracks", variant="primary", size="lg")
                with gr.Column():
                    output_mask2 = gr.Image(
                        label="Crack Mask",
                        height=400
                    )
                    output_overlay2 = gr.Image(
                        label="Overlay Visualization",
                        height=400
                    )
            predict_btn2.click(
                predict_with_threshold,
                inputs=[input_image2, threshold_slider],
                outputs=[output_mask2, output_overlay2]
            )
    gr.Markdown(
        """
        ### How to use:
        1. **Upload** a concrete surface image
        2. **Click** "Detect Cracks" to run the segmentation
        3. **View** the results: white areas in the mask indicate detected cracks
        4. **Adjust** the threshold in Advanced mode for fine-tuning sensitivity

        ### Model Information:
        - **Architecture**: Improved UNet with BatchNorm and Dropout
        - **Input Size**: Images are resized to 128x128 for processing
        - **Output**: Binary segmentation mask highlighting crack regions
        - **Training**: Optimized for imbalanced crack detection using advanced loss functions

        ### Tips for better results:
        - Use high-contrast images where cracks are visible
        - Ensure good lighting conditions
        - Try adjusting the threshold if results seem too sensitive or not sensitive enough
        """
    )
# Script entry point: serve the demo on all interfaces, port 7860.
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link; confirm
    # that this is wanted in the deployment environment.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)