Oill_split / model /unet.py
Utkarshres32's picture
Initial commit: AI-powered Oil Spill Detection and Monitoring System
7a5bb5d
import tensorflow as tf
from tensorflow.keras import layers, models
def conv_block(input_tensor, num_filters):
    """Apply two stacked convolution stages to ``input_tensor``.

    Each stage is ``Conv2D(num_filters, 3x3, padding="same")`` followed by
    ``BatchNormalization`` and a ReLU activation. With "same" padding and the
    default stride of 1, the spatial size is preserved and the channel count
    becomes ``num_filters``.

    Args:
        input_tensor: Keras tensor to transform.
        num_filters: Number of filters used by both convolutions.

    Returns:
        The Keras tensor produced by the second stage.
    """
    features = input_tensor
    # Layers are created in the same order as two hand-written stages would be.
    for _stage in range(2):
        features = layers.Conv2D(num_filters, (3, 3), padding="same")(features)
        features = layers.BatchNormalization()(features)
        features = layers.Activation("relu")(features)
    return features
def encoder_block(input_tensor, num_filters):
    """Run one U-Net encoder stage.

    Args:
        input_tensor: Keras tensor entering this stage.
        num_filters: Filter count passed to ``conv_block``.

    Returns:
        A ``(skip, pooled)`` pair. ``skip`` holds the ``conv_block`` features,
        kept for the matching decoder's skip connection. ``pooled`` is the
        result of 2x2 max pooling on ``skip``, which is fed to the next stage.
    """
    skip = conv_block(input_tensor, num_filters)
    pooled = layers.MaxPooling2D((2, 2))(skip)
    return skip, pooled
def decoder_block(input_tensor, concat_tensor, num_filters):
    """Run one U-Net decoder stage.

    The input is upsampled with a 2x2, stride-2 transposed convolution. The
    result is concatenated with ``concat_tensor`` (the skip connection from the
    matching encoder stage) and passed through ``conv_block``.

    Args:
        input_tensor: Keras tensor from the previous (deeper) stage.
        concat_tensor: Skip-connection tensor. Its spatial size must match the
            upsampled ``input_tensor`` for the concatenation to succeed.
        num_filters: Filter count for the transposed convolution and for
            ``conv_block``.

    Returns:
        The Keras tensor produced by ``conv_block``.
    """
    upsample = layers.Conv2DTranspose(
        num_filters,
        kernel_size=(2, 2),
        strides=(2, 2),
        padding="same",
    )
    merged = layers.concatenate([upsample(input_tensor), concat_tensor])
    return conv_block(merged, num_filters)
def build_unet(input_shape=(256, 256, 3), num_classes=1):
    """Build a U-Net architecture for image segmentation.

    The network has four encoder stages (64, 128, 256 and 512 filters), each
    followed by 2x2 max pooling, then a 1024-filter bridge. Four mirrored
    decoder stages follow; each upsamples with a stride-2 transposed
    convolution and concatenates the matching encoder features as a skip
    connection.

    Args:
        input_shape: ``(height, width, channels)`` of the input images. A known
            height or width must be divisible by 16 (one factor of 2 per
            pooling stage) so that every upsampled decoder tensor lines up with
            its skip connection. ``None`` dimensions are accepted without a
            check here.
        num_classes: Number of output channels. The default ``1`` gives a
            single sigmoid channel (oil spill vs. background). Values greater
            than 1 give a per-pixel softmax over ``num_classes`` channels.

    Returns:
        A Keras ``Model`` named ``"U-Net"`` whose output has ``num_classes``
        channels.

    Raises:
        ValueError: If ``num_classes`` is less than 1, or if a known height or
            width in ``input_shape`` is not divisible by 16.
    """
    encoder_filters = (64, 128, 256, 512)
    # Every pooling stage halves height/width and every decoder stage doubles
    # them back. Known sizes must therefore be multiples of 2**depth, or the
    # upsampled tensors will not match their skip connections at concatenation.
    size_multiple = 2 ** len(encoder_filters)

    if num_classes < 1:
        raise ValueError(f"num_classes must be at least 1, got {num_classes}.")
    for axis_name, size in zip(("height", "width"), input_shape[:2]):
        if size is not None and size % size_multiple != 0:
            raise ValueError(
                f"Input {axis_name} must be divisible by {size_multiple} "
                f"(got input_shape={tuple(input_shape)})."
            )

    inputs = layers.Input(shape=input_shape)

    # Encoder: keep each stage's pre-pooling features for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bridge (twice the width of the deepest encoder stage, i.e. 1024).
    x = conv_block(x, 2 * encoder_filters[-1])

    # Decoder: mirror the encoder, starting from the deepest skip connection.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = decoder_block(x, skip, filters)

    # Output head: sigmoid for a single binary mask (the default, oil spill vs.
    # background), softmax across channels for multi-class masks.
    activation = "sigmoid" if num_classes == 1 else "softmax"
    outputs = layers.Conv2D(
        num_classes, (1, 1), padding="same", activation=activation
    )(x)

    model = models.Model(inputs, outputs, name="U-Net")
    return model
if __name__ == "__main__":
    # Quick sanity check: build the default model and print its layer summary.
    build_unet().summary()