File size: 4,007 Bytes
c9b451b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import cv2
import torch
import numpy as np
from pathlib import Path
import os
import shutil
import tempfile
from single_seg import load_unet_model, process_tile

# Both models are loaded once at import time so every request reuses them.
model_unet_path = './weights/unet_epoch_24.pth'

# Custom-trained YOLOv5 detector loaded through torch.hub from ./weights/best.pt.
yolo_model = torch.hub.load(
    'ultralytics/yolov5',
    'custom',
    path='./weights/best.pt',
)

# The UNet segmenter is optional: if its checkpoint file is absent,
# unet_model is left as None.
if Path(model_unet_path).exists():
    unet_model = load_unet_model(model_path=model_unet_path)
else:
    unet_model = None

def _centered_window(lo, hi, size, limit):
    """Return a ``(start, stop)`` span centred on ``[lo, hi)`` that contains it.

    When ``hi - lo`` is smaller than *size*, the span is padded to exactly
    *size*; an odd leftover pixel goes after the box. When the box is already
    at least *size* wide, the span is the box itself, so the detection is
    never cropped. The span is then clipped to ``[0, limit)``, so windows
    touching the image border can be shorter than *size*.
    """
    pad = max(size - (hi - lo), 0)
    start = lo - pad // 2
    stop = hi + (pad - pad // 2)
    return max(0, start), min(limit, stop)


def process_image(image_input, confidence_threshold=0.5):
    """Detect moles with YOLOv5, annotate them, and run UNet segmentation per detection.

    Args:
        image_input: numpy image in RGB channel order. It is converted with
            ``cv2.COLOR_RGB2BGR`` before any OpenCV work.
        confidence_threshold: detections whose confidence is below this value
            are skipped.

    Returns:
        A tuple ``(annotated_image, info)``:
        - ``annotated_image`` is the input (converted back to RGB) with a green
          box and a "Mole N: conf" label for each kept detection.
        - ``info`` is a text summary listing every kept detection, numbered
          1..N.
    """
    tile_size = 256  # target side length of the crop handed to the UNet

    image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
    output_image = image.copy()
    img_h, img_w = image.shape[:2]
    info_lines = []

    # The context manager removes all scratch files on exit, including when
    # YOLO, the UNet or cv2 raises part-way through.
    with tempfile.TemporaryDirectory() as tmp:
        temp_dir = Path(tmp)
        output_base = temp_dir / "predictions" / "output"
        tiles_dir = output_base / "tiles"
        seg_dir = output_base / "segment"
        tiles_dir.mkdir(parents=True)
        seg_dir.mkdir()

        # YOLO is given a file path, so the BGR image is written to disk first.
        temp_img_path = temp_dir / "input.jpg"
        cv2.imwrite(str(temp_img_path), image)
        results = yolo_model(str(temp_img_path))

        for idx, detection in enumerate(results.xyxy[0]):
            x1, y1, x2, y2, conf, _class_id = detection.tolist()
            if conf < confidence_threshold:
                continue
            x1, y1, x2, y2 = map(int, (x1, y1, x2, y2))
            # Number only kept detections, so labels run 1..N without gaps
            # left by detections below the threshold.
            mole_no = len(info_lines) + 1

            # unet_model is None when its checkpoint was missing at start-up;
            # skip segmentation instead of passing None to process_tile.
            if unet_model is not None:
                tx1, tx2 = _centered_window(x1, x2, tile_size, img_w)
                ty1, ty2 = _centered_window(y1, y2, tile_size, img_h)
                tile_filename = f"tile_{idx}.jpg"
                tile_path = tiles_dir / tile_filename
                cv2.imwrite(str(tile_path), image[ty1:ty2, tx1:tx2])
                # NOTE(review): the segmentation output is written inside the
                # temporary directory, which is deleted before this function
                # returns, and nothing here reads it back. Confirm whether the
                # mask should be surfaced to the caller.
                process_tile(tile_path, unet_model, save_dir=seg_dir / tile_filename)

            cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Keep the label on-canvas for boxes touching the top edge.
            label_y = max(y1 - 10, 15)
            cv2.putText(output_image, f"Mole {mole_no}: {conf:.2f}", (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            info_lines.append(
                f"Mole {mole_no}: Confidence {conf:.2f}, "
                f"Location ({x1}, {y1}) to ({x2}, {y2})\n"
            )

    info = f"Found {len(info_lines)} potential moles:\n\n" + "".join(info_lines)
    return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB), info

def create_interface():
    """Build the Gradio Blocks UI for the mole detector.

    Layout:
    - First row, left column: image upload, confidence slider and an
      "Analyze Image" button.
    - First row, right column: the annotated output image.
    - Second row: the Markdown detection summary.

    Clicking the button runs ``process_image`` on the uploaded image.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """

    def on_submit(image, conf_threshold):
        # With no upload the image input yields None; show a prompt instead.
        if image is not None:
            return process_image(image, conf_threshold)
        return None, "Please upload an image!"

    with gr.Blocks(title="Skin Mole Detection App", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Skin Mole Detection App
        Upload an image to detect skin moles using advanced AI models. 
        Features YOLOv5 detection.
        """)

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="numpy", label="Upload Image")
                conf_slider = gr.Slider(
                    0, 1, value=0.5, step=0.05, label="Confidence Threshold"
                )
                submit_btn = gr.Button("Analyze Image", variant="primary")
            with gr.Column():
                output_image = gr.Image(type="numpy", label="Detected Moles")

        with gr.Row():
            info_text = gr.Markdown(label="Detection Info")

        submit_btn.click(
            fn=on_submit,
            inputs=[image_input, conf_slider],
            outputs=[output_image, info_text],
        )

    return demo

# Launch for Paperspace Deployment.
# The UI object is still built at import time, so `interface` stays available
# to importers. The server itself starts only when this file is executed as a
# script, so importing the module no longer launches it as a side effect.
interface = create_interface()

if __name__ == "__main__":
    # 0.0.0.0 listens on all network interfaces so the app is reachable from
    # outside the container. share=False means no public Gradio share link.
    interface.launch(server_name="0.0.0.0", server_port=7860, share=False)