import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import cv2
import numpy as np
from flask import jsonify
import functions_framework

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def remove_hair_and_markings(img_bgr, kernel_size=17):
    """Remove thin hair and markings using morphological operations."""
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # Black-hat highlights thin dark structures (hairs, ruler markings) on lighter skin.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
    blackhat = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, kernel)
    # Threshold the black-hat response into a binary mask of pixels to repair.
    _, mask = cv2.threshold(blackhat, 10, 255, cv2.THRESH_BINARY)
    # Fill the masked pixels from their surroundings (Telea inpainting, radius 3).
    inpainted = cv2.inpaint(img_bgr, mask, 3, cv2.INPAINT_TELEA)
    return inpainted


def normalize_color(img_bgr):
    """Per-channel normalization: subtract mean and divide by std for each channel."""
    img_float = img_bgr.astype(np.float32)
    b, g, r = cv2.split(img_float)
    for channel in (b, g, r):
        mean_val = np.mean(channel)
        std_val = np.std(channel) + 1e-8
        channel[:] = (channel - mean_val) / std_val
    normalized = cv2.merge([b, g, r])
    return normalized


def mmwf_filter(gray_image, window_size=3, noise_variance=0.01):
    """Apply the Median-Modified Wiener Filter (MMWF) on a grayscale image.

    A straightforward (non-vectorized) reference implementation: the classic
    per-pixel Wiener update, with the local mean replaced by the local median.
    """
    if gray_image.dtype != np.float32:
        gray_image = gray_image.astype(np.float32)
    pad_size = window_size // 2
    padded = np.pad(gray_image, pad_size, mode='reflect')
    filtered = np.zeros_like(gray_image, dtype=np.float32)
    rows, cols = gray_image.shape
    for i in range(rows):
        for j in range(cols):
            local_patch = padded[i:i+window_size, j:j+window_size]
            mu_m = np.median(local_patch)
            sigma_sq = local_patch.var()
            a_val = gray_image[i, j]
            if sigma_sq < 1e-12:
                filtered[i, j] = mu_m
            else:
                # Clamp the gain at zero, as in the standard adaptive Wiener
                # filter, so regions whose local variance falls below the noise
                # estimate smooth to the median instead of overshooting with a
                # negative coefficient.
                gain = max(sigma_sq - noise_variance, 0.0) / sigma_sq
                filtered[i, j] = mu_m + gain * (a_val - mu_m)
    return filtered


def mmwf_filter_color(img_bgr, window_size=3, noise_variance=0.01):
    """Apply MMWF to each channel of a BGR image."""
    if img_bgr.dtype != np.float32:
        img_bgr = img_bgr.astype(np.float32)
    b, g, r = cv2.split(img_bgr)
    b_denoised = mmwf_filter(b, window_size, noise_variance)
    g_denoised = mmwf_filter(g, window_size, noise_variance)
    r_denoised = mmwf_filter(r, window_size, noise_variance)
    denoised_bgr = cv2.merge([b_denoised, g_denoised, r_denoised])
    return denoised_bgr


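# Illustrative usage of the optional MMWF denoising step (not part of the
# default pipeline below; it could be enabled between color normalization and
# rescaling when inputs are noisy):
#   img_denoised = mmwf_filter_color(img_bgr, window_size=3, noise_variance=0.01)

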
def preprocess_image(image):
    """
    Preprocess a PIL image before inference:
    1. Convert from PIL RGB to NumPy BGR.
    2. Remove hair/markings.
    3. Normalize color per channel.
    4. Rescale back to the uint8 range.
    5. Convert back to PIL RGB.
    The MMWF step (mmwf_filter_color) is defined above but skipped by default;
    it can be inserted after step 3 if needed.
    """
    image_np = np.array(image)
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    image_bgr = remove_hair_and_markings(image_bgr, kernel_size=17)
    image_bgr = normalize_color(image_bgr)
    # Map the zero-mean, unit-variance floats back into displayable 0-255 uint8.
    image_bgr = cv2.normalize(image_bgr, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX).astype(np.uint8)
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(image_rgb)


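# Quick local sanity check (illustrative; 'sample.jpg' is a placeholder path):
#   img = Image.open('sample.jpg').convert('RGB')
#   preprocess_image(img).save('sample_preprocessed.jpg')

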
def create_classification_head(input_dim, num_classes):
    return nn.Sequential(
        nn.Linear(input_dim, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes)
    )


class MetadataEncoder(nn.Module):
    def __init__(self, input_size, output_size):
        super(MetadataEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * input_size, output_size),
            nn.ReLU()
        )

    def forward(self, x):
        # (batch, input_size) -> (batch, 1, input_size) so Conv1d sees one channel.
        x = x.unsqueeze(1)
        return self.encoder(x)


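# Shape sketch for metadata_input_size=3, output_size=768 (illustrative):
#   (batch, 3) -> unsqueeze -> (batch, 1, 3) -> conv stack (padding keeps
#   length 3) -> (batch, 32, 3) -> flatten -> (batch, 96) -> linear -> (batch, 768).

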
class GraphAttentionLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.0, alpha=0.2):
        super(GraphAttentionLayer, self).__init__()
        self.W = nn.Linear(in_features, out_features, bias=False)
        self.a = nn.Parameter(torch.empty(2 * out_features, 1))
        nn.init.xavier_uniform_(self.W.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h, adj=None):
        # h: (batch, N, in_features); adj is accepted for API compatibility but
        # unused, since attention is computed over the fully connected node set.
        Wh = self.W(h)
        batch_size, N, _ = Wh.size()
        # Build all (i, j) node pairs and score each with the attention vector a.
        Wh_i = Wh.unsqueeze(2).repeat(1, 1, N, 1)
        Wh_j = Wh.unsqueeze(1).repeat(1, N, 1, 1)
        e = self.leakyrelu(torch.matmul(torch.cat([Wh_i, Wh_j], dim=-1), self.a).squeeze(-1))
        attention = F.softmax(e, dim=-1)
        attention = self.dropout(attention)
        h_prime = torch.matmul(attention, Wh)
        return h_prime, attention


class MultiHeadGraphAttentionLayer(nn.Module):
    def __init__(self, in_features, out_features, num_heads=4, dropout=0.0, alpha=0.2):
        super(MultiHeadGraphAttentionLayer, self).__init__()
        self.num_heads = num_heads
        self.heads = nn.ModuleList([
            GraphAttentionLayer(in_features, out_features, dropout, alpha)
            for _ in range(num_heads)
        ])
        self.linear = nn.Linear(num_heads * out_features, out_features)

    def forward(self, h):
        head_outputs = [head(h)[0] for head in self.heads]
        h_concat = torch.cat(head_outputs, dim=-1)
        return self.linear(h_concat)


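# Design note: per-head outputs are concatenated and linearly mixed back down
# to out_features, so each layer's input and output widths match and the
# residual connections in EnhancedGraphFusion below line up without projections.

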
class EnhancedGraphFusion(nn.Module):
    def __init__(self, in_features, out_features, num_heads=4, num_layers=2, dropout=0.0, alpha=0.2):
        super(EnhancedGraphFusion, self).__init__()
        # Learnable initialization for the extra "global" readout node.
        self.global_init = nn.Parameter(torch.zeros(in_features))
        self.layers = nn.ModuleList([
            MultiHeadGraphAttentionLayer(in_features, in_features, num_heads, dropout, alpha)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(in_features)
        self.proj = nn.Linear(in_features, out_features)

    def forward(self, image_nodes, metadata_feat):
        batch_size = image_nodes.size(0)
        metadata_node = metadata_feat.unsqueeze(1)
        global_node = self.global_init.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)
        # Node set: [image patches | metadata | global], shape (batch, N + 2, in_features).
        h = torch.cat([image_nodes, metadata_node, global_node], dim=1)
        for layer in self.layers:
            residual = h
            h = layer(h)
            h = F.elu(h)
            h = self.norm(h + residual)
        # The last node is the global readout; project it to the output size.
        fused = h[:, -1, :]
        return self.proj(fused)


class DenseNet201Classifier(nn.Module):
    def __init__(self, num_classes=6, metadata_input_size=3, metadata_output_size=768):
        super(DenseNet201Classifier, self).__init__()
        # weights=... is the current torchvision API; pretrained=True is deprecated.
        self.densenet = models.densenet201(weights=models.DenseNet201_Weights.IMAGENET1K_V1)
        self.num_features = self.densenet.classifier.in_features
        self.img_proj = nn.Linear(self.num_features, metadata_output_size)
        self.metadata_encoder = MetadataEncoder(metadata_input_size, metadata_output_size)
        self.enhanced_graph_fusion = EnhancedGraphFusion(
            in_features=metadata_output_size,
            out_features=metadata_output_size,
            num_heads=4,
            num_layers=2,
            dropout=0.1,
            alpha=0.2
        )
        self.head = create_classification_head(metadata_output_size, num_classes)

    def forward_feature_map(self, x):
        features = self.densenet.features(x)
        return F.relu(features, inplace=True)

    def forward(self, x, metadata):
        fmap = self.forward_feature_map(x)
        batch_size, C, H, W = fmap.shape
        # Treat each spatial location of the feature map as a graph node: (batch, H*W, C).
        image_nodes = fmap.view(batch_size, C, H * W).transpose(1, 2)
        image_nodes = self.img_proj(image_nodes)
        metadata_features = self.metadata_encoder(metadata)
        fused_features = self.enhanced_graph_fusion(image_nodes, metadata_features)
        return self.head(fused_features)


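# Shape walk-through for a single 224x224 input (illustrative):
#   densenet.features -> (1, 1920, 7, 7), i.e. 49 image nodes of dim 1920
#   img_proj          -> 49 nodes of dim 768
#   metadata_encoder  -> 1 metadata node of dim 768
#   graph fusion      -> attention over 49 + 1 + 1 = 51 nodes; the global
#                        node is read out and classified into 6 classes.

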
# Standard ImageNet preprocessing, matching the DenseNet backbone's training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def load_model(model_path='pd4_model.pth'):
    """Load the pre-trained DenseNet201Classifier model."""
    num_classes = 6
    metadata_input_size = 3
    model = DenseNet201Classifier(num_classes=num_classes,
                                  metadata_input_size=metadata_input_size,
                                  metadata_output_size=768)
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# Load once at import time so warm invocations reuse the model.
model = load_model()


@functions_framework.http
def predict(request):
    """
    HTTP-triggered Cloud Function that accepts a POST request with an image
    under the 'file' form field, runs the preprocessing pipeline, and returns
    the predicted class and confidence as JSON.
    """
    if request.method != 'POST':
        return jsonify({'error': 'Only POST method is supported.'}), 405

    if 'file' not in request.files:
        return jsonify({'error': 'No file provided.'}), 400

    file = request.files['file']
    try:
        image = Image.open(file.stream).convert('RGB')
    except Exception:
        return jsonify({'error': 'Invalid image file.'}), 400

    image = preprocess_image(image)
    image_tensor = transform(image).unsqueeze(0).to(device)

    # No patient metadata accompanies the request, so a zero vector serves as
    # a neutral placeholder for the three metadata features.
    default_metadata = torch.zeros((1, 3), dtype=torch.float).to(device)

    with torch.no_grad():
        outputs = model(image_tensor, default_metadata)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probabilities, dim=1)

    result = {
        'predicted_class': predicted.item(),
        'confidence': confidence.item()
    }
    return jsonify(result)


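# Example request once the function is deployed or served locally
# (illustrative; the URL and filename are placeholders):
#   curl -X POST -F "file=@lesion.jpg" http://localhost:8080/
# which returns JSON like {"predicted_class": 2, "confidence": 0.87}.

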
if __name__ == '__main__':
    # Local development entry point; the deployed Cloud Function ignores this block.
    port = int(os.environ.get("PORT", 8080))

    from functions_framework import create_app
    app = create_app(target="predict")
    app.run(host="0.0.0.0", port=port, debug=True)