AndaiMD committed on
Commit
259d78e
·
1 Parent(s): 0ad52ef

unet model

Browse files
app.py CHANGED
@@ -8,13 +8,13 @@ import gradio as gr
8
  import torch
9
  import numpy as np
10
  from torchvision import transforms
11
- from models.unet import UNet
12
  from PIL import Image
13
  import matplotlib.pyplot as plt
14
  import io
15
 
16
  # Load model
17
- model_path = "checkpoints/unet_epoch10.pth"
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
  model = UNet(in_channels=1, out_channels=3).to(device)
 
8
  import torch
9
  import numpy as np
10
  from torchvision import transforms
11
+ from model.unet import UNet
12
  from PIL import Image
13
  import matplotlib.pyplot as plt
14
  import io
15
 
16
  # Load model
17
+ model_path = "model/unet_epoch20.pth"
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
  model = UNet(in_channels=1, out_channels=3).to(device)
model/__pycache__/unet.cpython-313.pyc ADDED
Binary file (4.84 kB). View file
 
model/unet.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class DoubleConv(nn.Module):
    """Two (Conv2d -> BatchNorm2d -> ReLU) stages applied back to back.

    Core building block of the U-Net: both 3x3 convolutions use padding=1,
    so spatial dimensions are preserved; only the channel count changes
    (in_channels -> out_channels on the first conv, out_channels ->
    out_channels on the second). BatchNorm after each conv stabilizes and
    speeds up training.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Assemble the two conv stages as a flat list, then wrap them once
        # in nn.Sequential. Layer order (conv, bn, relu, conv, bn, relu)
        # maps to indices 0-5, keeping state-dict keys stable for
        # checkpoint loading.
        stages = []
        channels = in_channels
        for _ in range(2):
            stages.append(nn.Conv2d(channels, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
            channels = out_channels
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv stages over *x* and return the result."""
        return self.double_conv(x)
36
+
37
class UNet(nn.Module):
    """Standard 4-level U-Net mapping ``in_channels`` -> ``out_channels``.

    The encoder halves the spatial size four times while growing channels
    (64 -> 128 -> 256 -> 512), a 1024-channel bottleneck sits at the bottom
    of the "U", and the decoder upsamples back with transposed convolutions,
    concatenating the matching encoder feature map (skip connection) before
    each DoubleConv. A final 1x1 convolution projects to ``out_channels``.

    NOTE(review): the skip concatenations assume input height/width are
    divisible by 16 so encoder/decoder spatial sizes line up — confirm
    with the caller.
    """

    def __init__(self, in_channels=1, out_channels=3):
        super().__init__()

        # --- Contracting path -------------------------------------------
        self.down1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # --- Bottom of the "U" ------------------------------------------
        self.bottleneck = DoubleConv(512, 1024)

        # --- Expanding path ---------------------------------------------
        # Each dec* receives the up* output concatenated with its skip,
        # which is why its input channel count is doubled.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Return a tensor with ``out_channels`` channels and the spatial
        size of *x* (for inputs whose height/width divide by 16)."""
        # Encoder: keep every level's output for the skip connections.
        skip1 = self.down1(x)
        skip2 = self.down2(self.pool1(skip1))
        skip3 = self.down3(self.pool2(skip2))
        skip4 = self.down4(self.pool3(skip3))

        feats = self.bottleneck(self.pool4(skip4))

        # Decoder: upsample, concatenate the skip along channels, refine.
        feats = self.dec4(torch.cat([self.up4(feats), skip4], dim=1))
        feats = self.dec3(torch.cat([self.up3(feats), skip3], dim=1))
        feats = self.dec2(torch.cat([self.up2(feats), skip2], dim=1))
        feats = self.dec1(torch.cat([self.up1(feats), skip1], dim=1))

        return self.out_conv(feats)