darthvader2603 commited on
Commit
bf3e65e
·
verified ·
1 Parent(s): f112872

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +86 -0
  2. iris_segmentation_model.pth +3 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.models as models
4
+ import torchvision.transforms as T
5
+ import gradio as gr
6
+ import numpy as np
7
+ import cv2
8
+ from PIL import Image
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+
11
+ # === Model ===
12
+ class SimpleUNet(nn.Module):
13
+ def __init__(self):
14
+ super(SimpleUNet, self).__init__()
15
+ base_model = models.mobilenet_v2(pretrained=True).features
16
+ self.encoder = base_model
17
+ self.decoder = nn.Sequential(
18
+ nn.ConvTranspose2d(1280, 512, kernel_size=2, stride=2),
19
+ nn.ReLU(),
20
+ nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
21
+ nn.ReLU(),
22
+ nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
23
+ nn.ReLU(),
24
+ nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
25
+ nn.ReLU(),
26
+ nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2),
27
+ nn.Sigmoid()
28
+ )
29
+
30
+ def forward(self, x):
31
+ x = self.encoder(x)
32
+ x = self.decoder(x)
33
+ return x
34
+
35
+ # === Load Model ===
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ model = SimpleUNet().to(device)
38
+ model.load_state_dict(torch.load("iris_segmentation_model.pth", map_location=device))
39
+ model.eval()
40
+
41
+ # === Transform ===
42
+ transform = T.Compose([
43
+ T.Resize((224, 224)),
44
+ T.ToTensor(),
45
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
46
+ ])
47
+
48
+ # === Main function ===
49
+ def segment_iris(image):
50
+ frame = np.array(image)
51
+ gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
52
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
53
+ enhanced = clahe.apply(gray)
54
+ rgb_enhanced = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB)
55
+ pil_img = Image.fromarray(rgb_enhanced)
56
+
57
+ input_tensor = transform(pil_img).unsqueeze(0).to(device)
58
+
59
+ with torch.no_grad():
60
+ pred = model(input_tensor)
61
+ pred_mask = pred.squeeze().cpu().numpy()
62
+
63
+ binary_mask = (pred_mask > 0.5).astype(np.uint8) * 255
64
+ binary_mask = cv2.resize(binary_mask, (frame.shape[1], frame.shape[0]))
65
+ color_mask = cv2.applyColorMap(binary_mask, cv2.COLORMAP_JET)
66
+ blended = cv2.addWeighted(frame, 0.7, color_mask, 0.3, 0)
67
+
68
+ return Image.fromarray(blended)
69
+
70
+ # === Gradio UI ===
71
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
72
+ gr.Markdown("# 👁️ Iris Segmentation App")
73
+ gr.Markdown("Upload an eye image below. The model will perform iris segmentation and return a blended visualization. Webcam option only works when deployed locally.")
74
+
75
+ with gr.Row():
76
+ img_input = gr.Image(type="pil", label="📤 Upload Eye Image")
77
+ img_output = gr.Image(label="🎯 Segmentation Output")
78
+
79
+ img_input.change(fn=segment_iris, inputs=img_input, outputs=img_output)
80
+
81
+ gr.Markdown("⚠️ Webcam support is available only in local mode.")
82
+
83
+ # Placeholder for webcam mode
84
+ gr.Image(source="webcam", streaming=True, label="(Webcam – Local Only)", visible=False)
85
+
86
+ demo.launch()
iris_segmentation_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:769ed1ea4348ed4c273f477557fc46355575e1b6a92f88c40e45fff8c3600a41
3
+ size 22390414
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ gradio
4
+ opencv-python
5
+ scikit-learn
6
+ pillow