import os
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import cv2
import numpy as np
from flask import jsonify, request
import functions_framework

# Set device (Cloud Functions generally use CPU, but GPU will be used if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


###############################################################################
# Utility Functions for Image Preprocessing (Optional)
###############################################################################
def remove_hair_and_markings(img_bgr, kernel_size=17):
    """Remove thin hair and markings using morphological operations.

    Uses a black-hat transform to highlight thin dark structures (hair,
    ruler markings) on the lighter skin background, thresholds them into a
    mask, and inpaints the masked pixels.

    Args:
        img_bgr: uint8 BGR image (H, W, 3).
        kernel_size: side of the square structuring element; should exceed
            the typical hair width so hairs register in the black-hat image.

    Returns:
        Inpainted uint8 BGR image of the same shape.
    """
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
    blackhat = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, kernel)
    _, mask = cv2.threshold(blackhat, 10, 255, cv2.THRESH_BINARY)
    inpainted = cv2.inpaint(img_bgr, mask, 3, cv2.INPAINT_TELEA)
    return inpainted


def normalize_color(img_bgr):
    """Per-channel normalization: subtract mean and divide by std for each channel.

    Returns a float32 BGR image with (roughly) zero mean and unit variance
    per channel; the caller is expected to rescale it back to a displayable
    range (see preprocess_image).
    """
    img_float = img_bgr.astype(np.float32)
    b, g, r = cv2.split(img_float)
    for channel in (b, g, r):
        mean_val = np.mean(channel)
        std_val = np.std(channel) + 1e-8  # epsilon guards flat channels
        channel[:] = (channel - mean_val) / std_val  # in-place update
    normalized = cv2.merge([b, g, r])
    return normalized


def mmwf_filter(gray_image, window_size=3, noise_variance=0.01):
    """Apply the Median-Modified Wiener Filter (MMWF) on a grayscale image.

    For each pixel, with local window median mu_m and variance sigma^2:
        out = mu_m + ((sigma^2 - noise_variance) / sigma^2) * (pixel - mu_m)
    falling back to the plain median where the local variance is ~0.

    Vectorized with a sliding-window view instead of a per-pixel Python
    loop; results are identical to the naive formulation but computed in
    a handful of NumPy passes.

    Args:
        gray_image: 2-D array; converted to float32 if needed.
        window_size: odd window side length.
        noise_variance: assumed noise variance of the Wiener model.

    Returns:
        Filtered float32 array with the same shape as gray_image.
    """
    if gray_image.dtype != np.float32:
        gray_image = gray_image.astype(np.float32)
    pad_size = window_size // 2
    padded = np.pad(gray_image, pad_size, mode='reflect')
    # (H, W, k, k) zero-copy view of every local window.
    windows = np.lib.stride_tricks.sliding_window_view(
        padded, (window_size, window_size))
    mu_m = np.median(windows, axis=(-2, -1))
    sigma_sq = windows.var(axis=(-2, -1))
    # Guard the division: where variance is effectively zero, the output is
    # just the median and the Wiener gain is irrelevant.
    flat = sigma_sq < 1e-12
    safe_sigma = np.where(flat, 1.0, sigma_sq)
    gain = (sigma_sq - noise_variance) / safe_sigma
    filtered = np.where(flat, mu_m, mu_m + gain * (gray_image - mu_m))
    return filtered.astype(np.float32)


def mmwf_filter_color(img_bgr, window_size=3, noise_variance=0.01):
    """Apply MMWF to each channel of a BGR image.

    Returns a float32 BGR image assembled from the independently denoised
    B, G and R channels.
    """
    if img_bgr.dtype != np.float32:
        img_bgr = img_bgr.astype(np.float32)
    b, g, r = cv2.split(img_bgr)
    b_denoised = mmwf_filter(b, window_size, noise_variance)
    g_denoised = mmwf_filter(g, window_size, noise_variance)
    r_denoised = mmwf_filter(r, window_size, noise_variance)
    denoised_bgr = cv2.merge([b_denoised, g_denoised, r_denoised])
    return denoised_bgr


def preprocess_image(image):
    """
    Optionally, apply preprocessing steps:
      1. Convert from PIL RGB to NumPy BGR.
      2. Remove hair/markings.
      3. Normalize color.
      4. Apply MMWF filter.
      5. Convert back to PIL RGB.

    Args:
        image: PIL.Image in RGB mode.

    Returns:
        Preprocessed PIL.Image in RGB mode.
    """
    image_np = np.array(image)
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    image_bgr = remove_hair_and_markings(image_bgr, kernel_size=17)
    image_bgr = normalize_color(image_bgr)
    # normalize_color yields zero-mean float data; rescale to uint8 [0, 255]
    # so the downstream PIL/ToTensor pipeline sees a normal image.
    image_bgr = cv2.normalize(image_bgr, None, alpha=0, beta=255,
                              norm_type=cv2.NORM_MINMAX).astype(np.uint8)
    # image_bgr = mmwf_filter_color(image_bgr, window_size=3, noise_variance=0.01)
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(image_rgb)


###############################################################################
# Helper Modules for the Classifier
###############################################################################
def create_classification_head(input_dim, num_classes):
    """Build the final MLP classification head: input_dim -> 512 -> num_classes."""
    return nn.Sequential(
        nn.Linear(input_dim, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes)
    )


class MetadataEncoder(nn.Module):
    """Encode a flat metadata vector into a feature embedding.

    Treats the metadata vector as a 1-channel 1-D signal, applies two
    Conv1d/BatchNorm/ReLU stages, then projects the flattened result to
    `output_size` features.
    """

    def __init__(self, input_size, output_size):
        super(MetadataEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * input_size, output_size),
            nn.ReLU()
        )

    def forward(self, x):
        # (batch, input_size) -> (batch, 1, input_size) for Conv1d.
        x = x.unsqueeze(1)
        return self.encoder(x)


class GraphAttentionLayer(nn.Module):
    """Single-head graph attention layer (GAT-style) over fully connected nodes."""

    def __init__(self, in_features, out_features, dropout=0.0, alpha=0.2):
        super(GraphAttentionLayer, self).__init__()
        self.out_features = out_features
        self.W = nn.Linear(in_features, out_features, bias=False)
        self.a = nn.Parameter(torch.empty(2 * out_features, 1))
        nn.init.xavier_uniform_(self.W.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h, adj=None):
        """Attend over all node pairs.

        Args:
            h: node features (batch, N, in_features).
            adj: unused; kept for interface compatibility with adjacency-
                masked variants (attention here is over the full graph).

        Returns:
            (updated node features (batch, N, out_features),
             attention weights (batch, N, N)).
        """
        Wh = self.W(h)  # (batch, N, out_features)
        # Pairwise logits e[b, i, j] = a . [Wh_i ; Wh_j]. Split `a` into its
        # source/target halves so the concat+repeat (which materializes two
        # (B, N, N, F) tensors) becomes two (B, N, 1) projections combined by
        # broadcasting -- mathematically identical, O(N*F) memory instead of
        # O(N^2 * F).
        a_src = self.a[:self.out_features]   # applied to node i
        a_dst = self.a[self.out_features:]   # applied to node j
        e = self.leakyrelu(
            torch.matmul(Wh, a_src) +                    # (batch, N, 1)
            torch.matmul(Wh, a_dst).transpose(1, 2)      # (batch, 1, N)
        )
        attention = F.softmax(e, dim=-1)
        attention = self.dropout(attention)
        h_prime = torch.matmul(attention, Wh)
        return h_prime, attention


class MultiHeadGraphAttentionLayer(nn.Module):
    """Run several GraphAttentionLayer heads in parallel and merge their outputs."""

    def __init__(self, in_features, out_features, num_heads=4, dropout=0.0, alpha=0.2):
        super(MultiHeadGraphAttentionLayer, self).__init__()
        self.num_heads = num_heads
        self.heads = nn.ModuleList([
            GraphAttentionLayer(in_features, out_features, dropout, alpha)
            for _ in range(num_heads)
        ])
        self.linear = nn.Linear(num_heads * out_features, out_features)

    def forward(self, h):
        # Each head returns (features, attention); keep only the features.
        head_outputs = [head(h)[0] for head in self.heads]
        h_concat = torch.cat(head_outputs, dim=-1)
        return self.linear(h_concat)


class EnhancedGraphFusion(nn.Module):
    """Fuse image patch nodes with a metadata node via stacked graph attention.

    A learned global node is appended to the graph; after `num_layers`
    residual attention blocks, the global node's final state is projected
    to `out_features` and returned as the fused representation.
    """

    def __init__(self, in_features, out_features, num_heads=4, num_layers=2,
                 dropout=0.0, alpha=0.2):
        super(EnhancedGraphFusion, self).__init__()
        self.global_init = nn.Parameter(torch.zeros(in_features))
        self.layers = nn.ModuleList([
            MultiHeadGraphAttentionLayer(in_features, in_features,
                                         num_heads, dropout, alpha)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(in_features)
        self.proj = nn.Linear(in_features, out_features)

    def forward(self, image_nodes, metadata_feat):
        """
        Args:
            image_nodes: (batch, N, in_features) spatial feature nodes.
            metadata_feat: (batch, in_features) metadata embedding.

        Returns:
            (batch, out_features) fused feature taken from the global node.
        """
        batch_size = image_nodes.size(0)
        metadata_node = metadata_feat.unsqueeze(1)  # (batch, 1, F)
        global_node = self.global_init.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        # Node order: [image patches..., metadata, global]; the global node
        # is last, which is why the fused output reads h[:, -1, :].
        h = torch.cat([image_nodes, metadata_node, global_node], dim=1)
        for layer in self.layers:
            residual = h
            h = layer(h)
            h = F.elu(h)
            h = self.norm(h + residual)  # pre-activation residual + LayerNorm
        fused = h[:, -1, :]
        return self.proj(fused)


###############################################################################
# DenseNet201-based Classifier for Skin Lesion Classification
###############################################################################
class DenseNet201Classifier(nn.Module):
    """DenseNet201 backbone + metadata encoder fused by graph attention.

    The backbone's spatial feature map is flattened into patch nodes,
    projected to the metadata embedding size, fused with the encoded
    metadata through EnhancedGraphFusion, and classified by an MLP head.
    """

    def __init__(self, num_classes=6, metadata_input_size=3, metadata_output_size=768):
        super(DenseNet201Classifier, self).__init__()
        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13
        # in favor of `weights=`; kept as-is for compatibility with the
        # torchvision version the checkpoint was trained under — confirm.
        self.densenet = models.densenet201(pretrained=True)
        self.num_features = self.densenet.classifier.in_features
        self.img_proj = nn.Linear(self.num_features, metadata_output_size)
        self.metadata_encoder = MetadataEncoder(metadata_input_size, metadata_output_size)
        self.enhanced_graph_fusion = EnhancedGraphFusion(
            in_features=metadata_output_size,
            out_features=metadata_output_size,
            num_heads=4,
            num_layers=2,
            dropout=0.1,
            alpha=0.2
        )
        self.head = create_classification_head(metadata_output_size, num_classes)

    def forward_feature_map(self, x):
        """Return the ReLU-activated DenseNet feature map (batch, C, H, W)."""
        features = self.densenet.features(x)
        return F.relu(features, inplace=True)

    def forward(self, x, metadata):
        """
        Args:
            x: image batch (batch, 3, 224, 224) normalized per the transform.
            metadata: (batch, metadata_input_size) float tensor.

        Returns:
            (batch, num_classes) raw logits.
        """
        fmap = self.forward_feature_map(x)
        batch_size, C, H, W = fmap.shape
        # (batch, C, H, W) -> (batch, H*W, C): one node per spatial location.
        image_nodes = fmap.view(batch_size, C, H * W).transpose(1, 2)
        image_nodes = self.img_proj(image_nodes)
        metadata_features = self.metadata_encoder(metadata)
        fused_features = self.enhanced_graph_fusion(image_nodes, metadata_features)
        return self.head(fused_features)


###############################################################################
# Inference Pipeline Setup
###############################################################################
# Define image transformation for inference (ImageNet normalization stats).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def load_model(model_path='pd4_model.pth'):
    """Load the pre-trained DenseNet201Classifier model.

    Args:
        model_path: path to the saved state_dict checkpoint.

    Returns:
        The model in eval mode on the module-level `device`.
    """
    num_classes = 6
    metadata_input_size = 3  # Expecting: age, gender, skin_cancer_history
    model = DenseNet201Classifier(num_classes=num_classes,
                                  metadata_input_size=metadata_input_size,
                                  metadata_output_size=768)
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# Load the model once at function startup so warm invocations reuse it.
model = load_model()


###############################################################################
# Cloud Function HTTP Endpoint (Using functions_framework)
###############################################################################
@functions_framework.http
def predict(request):
    """
    HTTP-triggered Cloud Function that accepts a POST request with an image file.
    Optionally, you can uncomment the preprocessing step if needed.
    Returns the predicted class and confidence as JSON.
    """
    if request.method != 'POST':
        return jsonify({'error': 'Only POST method is supported.'}), 405
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided.'}), 400
    file = request.files['file']
    try:
        image = Image.open(file.stream).convert('RGB')
    except Exception:
        # Undecodable upload; details are deliberately not echoed back.
        return jsonify({'error': 'Invalid image file.'}), 400

    # Optionally apply advanced preprocessing:
    image = preprocess_image(image)

    image_tensor = transform(image).unsqueeze(0).to(device)
    # Use default metadata values (e.g., zeros for [age, gender, skin_cancer_history]).
    # NOTE(review): real metadata is not read from the request — confirm this
    # matches how the model was evaluated.
    default_metadata = torch.zeros((1, 3), dtype=torch.float).to(device)

    with torch.no_grad():
        outputs = model(image_tensor, default_metadata)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, dim=1)

    result = {
        'predicted_class': predicted.item(),
        'confidence': confidence.item()
    }
    return jsonify(result)


if __name__ == '__main__':
    # Local development entry point; in production the functions framework
    # serves the target itself. `os` and `functions_framework` are already
    # imported at module level.
    port = int(os.environ.get("PORT", 8080))
    app = functions_framework.create_app(target="predict")
    app.run(host="0.0.0.0", port=port, debug=True)