# Hugging Face upload artifact (uploader: pltnhan07, commit 96794d4).
import os
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import cv2
import numpy as np
from flask import jsonify, request
import functions_framework
# Select the compute device: prefer CUDA when available, otherwise CPU.
# (Cloud Functions generally run on CPU; GPU is used only if present.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
###############################################################################
# Utility Functions for Image Preprocessing (Optional)
###############################################################################
def remove_hair_and_markings(img_bgr, kernel_size=17):
    """Inpaint thin dark structures (hair, pen markings) out of a BGR image.

    A black-hat morphological transform highlights thin dark features on the
    grayscale image; thresholding that response yields an inpainting mask,
    which TELEA inpainting then fills in from the surrounding texture.
    """
    grayscale = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    struct_elem = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
    hair_response = cv2.morphologyEx(grayscale, cv2.MORPH_BLACKHAT, struct_elem)
    _, hair_mask = cv2.threshold(hair_response, 10, 255, cv2.THRESH_BINARY)
    return cv2.inpaint(img_bgr, hair_mask, 3, cv2.INPAINT_TELEA)
def normalize_color(img_bgr):
    """Standardize each color channel: subtract its mean, divide by its std.

    Args:
        img_bgr: H x W x 3 image array (any numeric dtype; converted to
            float32 before the statistics are computed).

    Returns:
        float32 array of the same shape whose channels have zero mean and
        unit standard deviation. A small epsilon guards against division by
        zero on perfectly flat channels.
    """
    img_float = img_bgr.astype(np.float32)
    # Vectorized per-channel statistics; replaces the cv2.split loop that
    # mutated each channel view in place.
    means = img_float.mean(axis=(0, 1), keepdims=True)
    stds = img_float.std(axis=(0, 1), keepdims=True) + 1e-8
    return (img_float - means) / stds
def mmwf_filter(gray_image, window_size=3, noise_variance=0.01):
    """Median-Modified Wiener Filter (MMWF) for a single-channel image.

    For every pixel, the Wiener update is applied around the local *median*
    (rather than the mean) of its window:

        out = median + ((var - noise_variance) / var) * (pixel - median)

    Windows with (near-)zero variance collapse to the plain median.
    Reflect padding keeps the output the same shape as the input.
    """
    if gray_image.dtype != np.float32:
        gray_image = gray_image.astype(np.float32)
    half = window_size // 2
    padded = np.pad(gray_image, half, mode='reflect')
    result = np.zeros_like(gray_image, dtype=np.float32)
    for row, col in np.ndindex(*gray_image.shape):
        window = padded[row:row + window_size, col:col + window_size]
        local_median = np.median(window)
        local_var = window.var()
        pixel = gray_image[row, col]
        if local_var < 1e-12:
            # Degenerate (flat) window: fall back to the median alone.
            result[row, col] = local_median
        else:
            gain = (local_var - noise_variance) / local_var
            result[row, col] = local_median + gain * (pixel - local_median)
    return result
def mmwf_filter_color(img_bgr, window_size=3, noise_variance=0.01):
    """Denoise a BGR image by running the MMWF on each channel independently."""
    if img_bgr.dtype != np.float32:
        img_bgr = img_bgr.astype(np.float32)
    channels = cv2.split(img_bgr)
    denoised = [mmwf_filter(ch, window_size, noise_variance) for ch in channels]
    return cv2.merge(denoised)
def preprocess_image(image):
    """Run the optional lesion-image cleanup pipeline on a PIL image.

    Steps:
      1. Convert PIL RGB to an OpenCV BGR array.
      2. Inpaint away hair and pen markings.
      3. Standardize each color channel, then rescale back to uint8 [0, 255].
      4. (MMWF denoising is available but currently disabled.)
      5. Convert back to a PIL RGB image.
    """
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    bgr = remove_hair_and_markings(bgr, kernel_size=17)
    bgr = normalize_color(bgr)
    # Map the standardized floats back onto a displayable 0-255 byte range.
    bgr = cv2.normalize(bgr, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX).astype(np.uint8)
    # bgr = mmwf_filter_color(bgr, window_size=3, noise_variance=0.01)
    return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
###############################################################################
# Helper Modules for the Classifier
###############################################################################
def create_classification_head(input_dim, num_classes):
    """Build a two-layer MLP head: Linear -> ReLU -> Dropout(0.5) -> Linear."""
    hidden_dim = 512
    layers = [
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(hidden_dim, num_classes),
    ]
    return nn.Sequential(*layers)
class MetadataEncoder(nn.Module):
    """Encode a flat metadata vector with 1-D convolutions.

    The (batch, input_size) input is treated as a 1-channel sequence, passed
    through two Conv1d/BatchNorm/ReLU stages, flattened, and projected to
    ``output_size`` features with a final ReLU.
    """

    def __init__(self, input_size, output_size):
        super(MetadataEncoder, self).__init__()
        stages = [
            nn.Conv1d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * input_size, output_size),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        # Insert the channel axis expected by Conv1d: (batch, 1, input_size).
        return self.encoder(x.unsqueeze(1))
class GraphAttentionLayer(nn.Module):
    """Single-head graph attention (GAT-style) over a batch of node sets.

    Attention is computed densely between all node pairs; the ``adj``
    argument is accepted for API compatibility but is not used to mask
    the attention matrix.
    """

    def __init__(self, in_features, out_features, dropout=0.0, alpha=0.2):
        super(GraphAttentionLayer, self).__init__()
        self.W = nn.Linear(in_features, out_features, bias=False)
        self.a = nn.Parameter(torch.empty(2 * out_features, 1))
        nn.init.xavier_uniform_(self.W.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h, adj=None):
        projected = self.W(h)  # (batch, N, out_features)
        batch_size, num_nodes, _ = projected.size()
        # Build every (i, j) pair of projected node features.
        src = projected.unsqueeze(2).repeat(1, 1, num_nodes, 1)
        dst = projected.unsqueeze(1).repeat(1, num_nodes, 1, 1)
        pair_feats = torch.cat([src, dst], dim=-1)  # (batch, N, N, 2*out)
        scores = self.leakyrelu(torch.matmul(pair_feats, self.a).squeeze(-1))
        attention = self.dropout(F.softmax(scores, dim=-1))
        # Weighted sum of projected neighbours for each node.
        return torch.matmul(attention, projected), attention
class MultiHeadGraphAttentionLayer(nn.Module):
    """Run several GraphAttentionLayer heads in parallel and merge them.

    Head outputs are concatenated along the feature axis and projected back
    to ``out_features`` with a single linear layer.
    """

    def __init__(self, in_features, out_features, num_heads=4, dropout=0.0, alpha=0.2):
        super(MultiHeadGraphAttentionLayer, self).__init__()
        self.num_heads = num_heads
        self.heads = nn.ModuleList(
            GraphAttentionLayer(in_features, out_features, dropout, alpha)
            for _ in range(num_heads)
        )
        self.linear = nn.Linear(num_heads * out_features, out_features)

    def forward(self, h):
        # Each head returns (features, attention); keep only the features.
        per_head = [head(h)[0] for head in self.heads]
        return self.linear(torch.cat(per_head, dim=-1))
class EnhancedGraphFusion(nn.Module):
    """Fuse image patch nodes with a metadata node via stacked graph attention.

    A learnable global node is appended to the node set; after ``num_layers``
    residual multi-head attention layers, that global node's representation
    (the last node) is projected to ``out_features`` and returned.
    """

    def __init__(self, in_features, out_features, num_heads=4, num_layers=2, dropout=0.0, alpha=0.2):
        super(EnhancedGraphFusion, self).__init__()
        self.global_init = nn.Parameter(torch.zeros(in_features))
        self.layers = nn.ModuleList(
            MultiHeadGraphAttentionLayer(in_features, in_features, num_heads, dropout, alpha)
            for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(in_features)
        self.proj = nn.Linear(in_features, out_features)

    def forward(self, image_nodes, metadata_feat):
        batch_size = image_nodes.size(0)
        # Assemble the node set: [image patches..., metadata, global].
        metadata_node = metadata_feat.unsqueeze(1)
        global_node = self.global_init.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        nodes = torch.cat([image_nodes, metadata_node, global_node], dim=1)
        for attn_layer in self.layers:
            skip = nodes
            # Residual connection around ELU(attention), then LayerNorm.
            nodes = self.norm(F.elu(attn_layer(nodes)) + skip)
        # The global node (last position) summarizes the whole graph.
        return self.proj(nodes[:, -1, :])
###############################################################################
# DenseNet201-based Classifier for Skin Lesion Classification
###############################################################################
class DenseNet201Classifier(nn.Module):
    """Skin-lesion classifier: DenseNet-201 backbone + metadata graph fusion.

    Feature-map spatial positions become graph nodes, a CNN-encoded metadata
    vector becomes one more node, and EnhancedGraphFusion combines them
    before the final classification head.
    """

    def __init__(self, num_classes=6, metadata_input_size=3, metadata_output_size=768):
        super(DenseNet201Classifier, self).__init__()
        # NOTE(review): pretrained weights are downloaded here even though
        # load_model() immediately overwrites them from a checkpoint —
        # consider pretrained=False if this class is only used for inference.
        self.densenet = models.densenet201(pretrained=True)
        self.num_features = self.densenet.classifier.in_features
        self.img_proj = nn.Linear(self.num_features, metadata_output_size)
        self.metadata_encoder = MetadataEncoder(metadata_input_size, metadata_output_size)
        self.enhanced_graph_fusion = EnhancedGraphFusion(
            in_features=metadata_output_size,
            out_features=metadata_output_size,
            num_heads=4,
            num_layers=2,
            dropout=0.1,
            alpha=0.2
        )
        self.head = create_classification_head(metadata_output_size, num_classes)

    def forward_feature_map(self, x):
        """Return the ReLU-activated DenseNet convolutional feature map."""
        return F.relu(self.densenet.features(x), inplace=True)

    def forward(self, x, metadata):
        fmap = self.forward_feature_map(x)
        n, channels, height, width = fmap.shape
        # Flatten spatial positions into nodes: (batch, H*W, C), then project.
        patch_nodes = fmap.view(n, channels, height * width).transpose(1, 2)
        patch_nodes = self.img_proj(patch_nodes)
        metadata_features = self.metadata_encoder(metadata)
        fused = self.enhanced_graph_fusion(patch_nodes, metadata_features)
        return self.head(fused)
###############################################################################
# Inference Pipeline Setup
###############################################################################
# Define image transformation for inference.
# Standard ImageNet-style preprocessing: fixed 224x224 input, tensor
# conversion, and normalization with the usual ImageNet channel statistics
# (matching the DenseNet backbone's pretraining).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
def load_model(model_path='pd4_model.pth'):
    """Construct the classifier and load checkpoint weights for inference.

    Args:
        model_path: Path to a saved state_dict checkpoint.

    Returns:
        The model in eval mode, moved to the global ``device``.
    """
    # Architecture hyperparameters must match the training run that produced
    # the checkpoint. Metadata vector: [age, gender, skin_cancer_history].
    classifier = DenseNet201Classifier(num_classes=6,
                                       metadata_input_size=3,
                                       metadata_output_size=768)
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — acceptable only for a trusted local checkpoint.
    checkpoint = torch.load(model_path, map_location='cpu')
    classifier.load_state_dict(checkpoint)
    classifier.to(device)
    classifier.eval()
    return classifier
# Load the model once at module import so warm invocations reuse it.
model = load_model()
###############################################################################
# Cloud Function HTTP Endpoint (Using functions_framework)
###############################################################################
@functions_framework.http
def predict(request):
    """HTTP Cloud Function: classify an uploaded skin-lesion image.

    Expects a POST request with the image under the multipart field ``file``.
    Metadata is defaulted to zeros for [age, gender, skin_cancer_history].

    Returns:
        JSON ``{"predicted_class": int, "confidence": float}`` on success,
        or ``{"error": ...}`` with a 4xx/405 status on bad input.
    """
    if request.method != 'POST':
        return jsonify({'error': 'Only POST method is supported.'}), 405
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided.'}), 400
    file = request.files['file']
    try:
        image = Image.open(file.stream).convert('RGB')
    # Fix: drop the unused `as e` binding; Pillow raises several unrelated
    # exception types on corrupt data, so a broad catch is deliberate here.
    except Exception:
        return jsonify({'error': 'Invalid image file.'}), 400
    # Advanced cleanup (hair removal, color normalization) before inference.
    image = preprocess_image(image)
    image_tensor = transform(image).unsqueeze(0).to(device)
    # Use default metadata values (zeros for [age, gender, skin_cancer_history]).
    default_metadata = torch.zeros((1, 3), dtype=torch.float).to(device)
    with torch.no_grad():
        outputs = model(image_tensor, default_metadata)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, dim=1)
    result = {
        'predicted_class': predicted.item(),
        'confidence': confidence.item()
    }
    return jsonify(result)
if __name__ == '__main__':
    # Local development entry point; in production functions_framework
    # serves the target function directly.
    # Fix: removed the redundant `import os` (already imported at file top).
    port = int(os.environ.get("PORT", 8080))
    # The functions_framework creates a Flask app that exposes your target function.
    from functions_framework import create_app
    app = create_app(target="predict")
    # NOTE(review): debug=True enables the werkzeug reloader/debugger —
    # keep it off for anything internet-facing.
    app.run(host="0.0.0.0", port=port, debug=True)