# Uploaded by balakrish181 in commit c9b451b ("add files").
import gradio as gr
import cv2
import torch
import numpy as np
from pathlib import Path
import os
import shutil
import tempfile
from single_seg import load_unet_model, process_tile
# Load both models a single time, at import, so every request reuses them.
# The UNet checkpoint is optional: when its file is missing, unet_model is
# set to None instead of being loaded.
model_unet_path = './weights/unet_epoch_24.pth'
yolo_model = torch.hub.load('ultralytics/yolov5', 'custom', path='./weights/best.pt')
if Path(model_unet_path).exists():
    unet_model = load_unet_model(model_path=model_unet_path)
else:
    unet_model = None
def _tile_window(lo, hi, limit, tile_size):
    """Return the ``(start, end)`` crop range of ``tile_size`` pixels centred on ``[lo, hi)``.

    The range is clipped to ``[0, limit)``, so crops that touch the image
    border come out shorter than ``tile_size``.
    """
    # Derive the end from the start rather than padding each side separately:
    # two floor-divided half-paddings lose a pixel when (tile_size - span) is odd.
    start = lo - (tile_size - (hi - lo)) // 2
    end = start + tile_size
    return max(0, start), min(limit, end)


def process_image(image_input, confidence_threshold=0.5):
    """Detect moles in an RGB image and return an annotated copy plus a text summary.

    Every YOLO detection whose confidence is at least ``confidence_threshold``
    is drawn as a green box with a sequential "Mole N" label. When a UNet model
    was loaded (``unet_model`` is not None), a crop of roughly ``tile_size``
    pixels around each kept box is also passed to ``process_tile``.

    NOTE(review): the YOLOv5 hub model applies its own ``conf`` threshold
    (0.25 by default per the YOLOv5 docs) before this filter, so slider values
    below that may have no visible effect — confirm.

    Args:
        image_input: Image as a numpy array in RGB channel order (what
            ``gr.Image(type="numpy")`` delivers).
        confidence_threshold: Detections with a lower confidence are skipped.

    Returns:
        Tuple ``(annotated_image_rgb, info_text)``.
    """
    tile_size = 256
    image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
    output_image = image.copy()
    img_h, img_w = image.shape[:2]
    info_lines = []

    # TemporaryDirectory removes the scratch files even if detection or
    # segmentation raises; a bare mkdtemp/rmtree pair leaks them on error.
    with tempfile.TemporaryDirectory() as tmp:
        temp_dir = Path(tmp)
        output_base = temp_dir / "predictions" / "output"
        tiles_dir = output_base / "tiles"
        seg_dir = output_base / "segment"
        tiles_dir.mkdir(parents=True, exist_ok=True)
        seg_dir.mkdir(parents=True, exist_ok=True)

        temp_img_path = temp_dir / "input.jpg"
        cv2.imwrite(str(temp_img_path), image)
        results = yolo_model(str(temp_img_path))

        for idx, detection in enumerate(results.xyxy[0]):
            x1, y1, x2, y2, conf, _class_id = detection.tolist()
            if conf < confidence_threshold:
                continue
            x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
            # Number only the kept detections so labels run 1..N without gaps
            # and agree with the count in the summary.
            mole_no = len(info_lines) + 1

            # Segmentation needs a loaded UNet; skip it rather than pass None.
            if unet_model is not None:
                tx1, tx2 = _tile_window(x1, x2, img_w, tile_size)
                ty1, ty2 = _tile_window(y1, y2, img_h, tile_size)
                tile = image[ty1:ty2, tx1:tx2]
                if tile.size:  # cv2.imwrite raises on an empty array
                    tile_filename = f"tile_{idx}.jpg"
                    tile_path = tiles_dir / tile_filename
                    cv2.imwrite(str(tile_path), tile)
                    # NOTE(review): the mask is saved inside the temporary
                    # directory, which is deleted on return, and the return
                    # value is ignored — the segmentation never reaches the UI.
                    process_tile(tile_path, unet_model, save_dir=seg_dir / tile_filename)

            label = f"Mole {mole_no}: {conf:.2f}"
            cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Clamp the text baseline so labels of boxes at the top edge stay visible.
            cv2.putText(output_image, label, (x1, max(y1 - 10, 10)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            info_lines.append(f"Mole {mole_no}: Confidence {conf:.2f}, "
                              f"Location ({x1}, {y1}) to ({x2}, {y2})\n")

    output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
    info = f"Found {len(info_lines)} potential moles:\n\n" + "".join(info_lines)
    return output_image, info
def create_interface():
    """Build the Gradio Blocks UI for mole detection and return it without launching it."""

    def on_submit(image, conf_threshold):
        # Clicking before an image has been uploaded yields a prompt, not an error.
        if image is not None:
            return process_image(image, conf_threshold)
        return None, "Please upload an image!"

    with gr.Blocks(title="Skin Mole Detection App", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "\n"
            "# Skin Mole Detection App\n"
            "Upload an image to detect skin moles using advanced AI models.\n"
            "Features YOLOv5 detection.\n"
        )
        with gr.Row():
            # Left column: inputs and the trigger button.
            with gr.Column():
                uploaded = gr.Image(type="numpy", label="Upload Image")
                threshold = gr.Slider(0, 1, value=0.5, step=0.05,
                                      label="Confidence Threshold")
                analyze = gr.Button("Analyze Image", variant="primary")
            # Right column: the annotated result.
            with gr.Column():
                annotated = gr.Image(type="numpy", label="Detected Moles")
        with gr.Row():
            summary = gr.Markdown(label="Detection Info")
        analyze.click(fn=on_submit, inputs=[uploaded, threshold], outputs=[annotated, summary])
    return demo
# Paperspace deployment: build the UI at import time and serve it on all
# network interfaces, port 7860, with no public share link.
interface = create_interface()
interface.launch(share=False, server_name="0.0.0.0", server_port=7860)