# ctcfpm / app.py
# Last commit: 7dbb617 — "bug: PolyU data (rotate)" (sol9x-sagar)
import os
import time
# --- Performance Logging Start
start_import = time.time()
from flask import Flask, request, jsonify, Response
import cv2
import numpy as np
from os.path import join, dirname
# ── PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from rembg import remove, new_session
from PIL import Image
import io
from flask import send_file
# Set model path for rembg
# (cache the U2Net weights beside the app instead of the user's home dir)
os.environ['U2NET_HOME'] = join(dirname(__file__), '.u2net')
# Import your custom enhancement functions
# NOTE(review): on ImportError, basicEnhancing/advancedEnhancing stay
# undefined, so the /enhance route will raise NameError at request time.
try:
    from enhancer import basicEnhancing, advancedEnhancing
except ImportError:
    print("Warning: 'enhancer' module not found.")
print(f"Imports loaded in: {time.time() - start_import:.2f}s")
# Flask application object; route handlers below register against it.
app = Flask(__name__)
# MODEL ARCHITECTURE (ResNet50 + ArcFace )
class FingerprintCNN(nn.Module):
    """ResNet50 backbone for 1-channel fingerprint images -> 128-dim embedding.

    Attribute names (conv1, bn1, ..., embedding) intentionally mirror the
    training-time module so checkpoint state_dict keys line up exactly.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        base = models.resnet50(weights=None)
        # Stem: swap the stock 3-channel conv for a 1-channel one
        # (fingerprints are grayscale); everything else reuses ResNet50.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = base.bn1
        self.relu = base.relu
        self.maxpool = base.maxpool
        self.layer1 = base.layer1
        self.layer2 = base.layer2
        self.layer3 = base.layer3
        self.layer4 = base.layer4
        self.avgpool = base.avgpool
        # Projection head: 2048 -> 512 -> embedding_dim with BN + dropout.
        self.embedding = nn.Sequential(
            nn.Linear(2048, 512), nn.BatchNorm1d(512),
            nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(512, embedding_dim), nn.BatchNorm1d(embedding_dim),
        )

    def forward(self, x):
        """Map a [B, 1, H, W] batch to L2-normalized [B, embedding_dim] vectors."""
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                      self.layer1, self.layer2, self.layer3, self.layer4,
                      self.avgpool):
            x = stage(x)
        feats = self.embedding(torch.flatten(x, 1))
        # Unit-norm embeddings: cosine similarity reduces to a dot product.
        return F.normalize(feats, p=2, dim=1)
class SiameseNetwork(nn.Module):
    """Siamese wrapper around FingerprintCNN.

    Inference code embeds each image separately via forward_one() and
    compares the resulting vectors outside the network.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        self.backbone = FingerprintCNN(embedding_dim)

    def forward_one(self, x):
        """Embed one batch of images -> [B, embedding_dim] L2-normalized."""
        return self.backbone(x)

    def forward(self, x):
        # Fix: previously no forward() was defined, so calling the module
        # directly (model(x)) raised NotImplementedError via nn.Module.
        # Delegate to forward_one so both call styles behave identically.
        return self.forward_one(x)
# CONFIGURATION
DIRNAME = dirname(__file__)
MODEL_PATH = join(DIRNAME, 'model', 'resnet50_arcface.pth') # Your trained model
IMAGE_SIZE = (224, 224)  # model input size; square, so cv2's (W, H) order is moot
EMBEDDING_DIM = 128  # must match the checkpoint's embedding head
# Thresholds
# Update ZERO_FA_THRESHOLD after running find_threshold.py on your data
# EER_THRESHOLD = 0.4505 # Balanced: 97% accuracy, 3% EER
# ZERO_FA_THRESHOLD = 0.5951 # Strict: 0 false accepts 84 % accuracy
# MATCH_THRESHOLD = ZERO_FA_THRESHOLD # Default: strict (zero false accepts)
MATCH_THRESHOLD = 0.75  # score at/above this counts as a match
# Image transform: grayscale ndarray -> PIL -> resize -> tensor scaled to [-1, 1]
img_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
# GLOBAL STATE
siamese_model = None  # populated by load_and_warmup_model() at startup
# NOTE(review): new_session downloads the U2Net weights on first run —
# a network side effect at import time.
rembg_session = new_session("u2net")
# MODEL LOADING & WARMUP
def load_and_warmup_model():
    """Load the Siamese model from MODEL_PATH into the module-global
    `siamese_model`, switch it to eval mode, and run one dummy forward
    pass so the first real request does not pay the warmup cost.

    On any failure the error is logged and `siamese_model` stays None;
    the routes check for this and respond with HTTP 500.
    """
    global siamese_model
    t_start = time.time()
    try:
        print(f"Loading model from: {MODEL_PATH}")
        siamese_model = SiameseNetwork(EMBEDDING_DIM)
        # SECURITY NOTE(review): weights_only=False performs full pickle
        # deserialization — acceptable only because MODEL_PATH is a trusted
        # bundled checkpoint; prefer weights_only=True for a plain state_dict.
        state_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
        siamese_model.load_state_dict(state_dict)
        siamese_model.eval()
        # Warmup: run dummy inference so first request isn't slow
        print("Warming up model...")
        dummy = torch.zeros(1, 1, 224, 224)
        with torch.no_grad():
            siamese_model.forward_one(dummy)
        print(f"Model loaded and warmed up in: {time.time() - t_start:.2f}s")
    except Exception as e:
        print(f"Critical Error: Could not load model: {e}")
        siamese_model = None
# Load model on startup
load_and_warmup_model()
# PREPROCESSING
def preprocess_fingerprint(image_bytes, image_type):
    """
    Decode and enhance a fingerprint image into a model-ready tensor.

    Mirrors the training-time pipeline exactly:
      grayscale decode -> CLAHE (clipLimit=3.0, 8x8 tiles) -> resize 224x224
      -> ToTensor + Normalize(0.5, 0.5).

    NOTE: `image_type` is currently unused — the contactless-specific
    alignment (rotate 90 CCW + horizontal mirror, used when testing PolyU
    data) is kept below, disabled.

    Returns:
        Tensor of shape [1, 1, 224, 224], or None if the payload is empty
        or cannot be decoded.
    """
    if not image_bytes:
        return None
    # Raw bytes -> grayscale cv2 image.
    decoded = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)
    if decoded is None:
        return None
    # # Contactless alignment (disabled — PolyU testing only)
    # if image_type == 'contactless':
    #     decoded = cv2.rotate(decoded, cv2.ROTATE_90_COUNTERCLOCKWISE)
    #     decoded = cv2.flip(decoded, 1)  # horizontal mirror
    # Local-contrast enhancement, then resize to the model's input size.
    equalized = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(decoded)
    resized = cv2.resize(equalized, IMAGE_SIZE)
    # [H, W] uint8 -> [1, 1, 224, 224] normalized float tensor.
    return img_transform(resized).unsqueeze(0)
def compute_match_score(emb1, emb2):
    """
    Fuse Euclidean distance and cosine similarity of two embeddings.

    Args:
        emb1, emb2: 1-D numpy arrays (the model emits 128-dim L2-normalized
            vectors, but normalization is not required here).

    Returns:
        Tuple (final_score, cos_sim, distance):
          final_score: float in [0, 1], higher = more similar — the mean of
              exp(-distance) and the cosine mapped to [0, 1];
          cos_sim: raw cosine similarity in [-1, 1];
          distance: Euclidean (L2) distance between the embeddings.
        (Fix: the old docstring claimed a single score was returned.)
    """
    distance = float(np.linalg.norm(emb1 - emb2))
    # Epsilon guards against division by zero for all-zero embeddings.
    cos_sim = float(np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-8))
    dist_score = float(np.exp(-distance))  # 1 at distance 0, decays toward 0
    cos_score = float((cos_sim + 1) / 2)   # map [-1, 1] -> [0, 1]
    final_score = 0.5 * dist_score + 0.5 * cos_score
    return final_score, cos_sim, distance
# BACKGROUND REMOVAL for enhancement
def remove_bg_from_cv2(cv2_img):
    """Converts CV2 image to PIL, removes background, returns CV2 (BGR)."""
    # OpenCV is BGR; PIL expects RGB.
    as_pil = Image.fromarray(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB))
    cut_out = remove(as_pil, session=rembg_session)
    # rembg yields RGBA; drop alpha and restore OpenCV's BGR channel order.
    return cv2.cvtColor(np.array(cut_out), cv2.COLOR_RGBA2BGR)
# ROUTES
@app.route('/', methods=['GET'])
def health_check():
    """Liveness endpoint: reports model status and the active threshold."""
    payload = {
        "status": "online",
        "message": "Fingerprint Comparison API is running",
        "model_loaded": siamese_model is not None,
        "model_type": "ResNet50-ArcFace (PyTorch)",
        "threshold": MATCH_THRESHOLD,
    }
    return jsonify(payload), 200
@app.route('/compare', methods=['POST'])
def compare_fingerprints():
    """
    Compare two uploaded fingerprint images.

    Multipart form fields:
        img_1, img_2   -- image files (required)
        type_1, type_2 -- capture types; default 'contactless' / 'contactbased'

    Returns:
        JSON with the fused score, match percentage, is_match flag and
        per-metric metadata, or an error payload (400 bad input,
        500 model unavailable).
    """
    request_start = time.time()

    # Guard clauses: both files present, model ready.
    if 'img_1' not in request.files or 'img_2' not in request.files:
        return jsonify({"error": "Missing image files"}), 400
    if siamese_model is None:
        return jsonify({"error": "Model not loaded"}), 500

    # 1. Read the raw upload bytes (timed).
    t_read = time.time()
    img_1_file = request.files['img_1'].read()
    img_2_file = request.files['img_2'].read()
    # Capture-type defaults match the common use case: photo vs scanner.
    type_1 = request.form.get('type_1', 'contactless')
    type_2 = request.form.get('type_2', 'contactbased')
    print(f"File Read Time: {time.time() - t_read:.4f}s")

    # 2. Run both images through the training-time preprocessing pipeline.
    t_p1 = time.time()
    processed_1 = preprocess_fingerprint(img_1_file, type_1)
    elapsed_1 = time.time() - t_p1
    print(f"Preprocessing Image 1 ({type_1}): {elapsed_1:.4f}s")

    t_p2 = time.time()
    processed_2 = preprocess_fingerprint(img_2_file, type_2)
    elapsed_2 = time.time() - t_p2
    print(f"Preprocessing Image 2 ({type_2}): {elapsed_2:.4f}s")

    if processed_1 is None or processed_2 is None:
        return jsonify({"error": "Failed to process images"}), 400

    # 3. Embed both images and fuse distance + cosine into one score.
    t_inf = time.time()
    with torch.no_grad():
        emb1 = siamese_model.forward_one(processed_1).numpy().flatten()
        emb2 = siamese_model.forward_one(processed_2).numpy().flatten()
        score, cosine_sim, euclidean_dist = compute_match_score(emb1, emb2)
    elapsed_inf = time.time() - t_inf
    print(f"Inference Time: {elapsed_inf:.4f}s")

    match_percentage = round(score * 100, 2)  # [0, 1] -> [0, 100]
    print(f"Total Request Time: {time.time() - request_start:.4f}s")

    # 4. JSON verdict.
    return jsonify({
        "score": round(score, 6),
        "match_percentage": match_percentage,
        "is_match": score >= MATCH_THRESHOLD,
        "metadata": {
            "img_1_type": type_1,
            "img_2_type": type_2,
            "threshold_used": MATCH_THRESHOLD,
            "cosine_similarity": round(cosine_sim, 6),
            "euclidean_distance": round(euclidean_dist, 6),
        }
    })
@app.route('/enhance', methods=['POST'])
def enhance_image_api():
    """
    Enhance an uploaded fingerprint photo.

    Pipeline: decode -> background removal (rembg) -> basicEnhancing ->
    advancedEnhancing -> in-memory JPEG (quality 95) streamed back.
    Any pipeline failure is returned as a JSON 500 with the message.
    """
    overall_start = time.time()
    if 'image' not in request.files:
        return jsonify({"error": "No image file provided."}), 400
    upload = request.files['image']
    try:
        # Decode the upload into a BGR cv2 image.
        t_start = time.time()
        raw = np.frombuffer(upload.read(), np.uint8)
        img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if img is None:
            return jsonify({"error": "Failed to decode image"}), 400
        print(f"[TIMING] Image Decode: {time.time() - t_start:.4f}s")

        # Background removal.
        t_bg = time.time()
        no_bg = remove_bg_from_cv2(img)
        print(f"[TIMING] Background Removal: {time.time() - t_bg:.4f}s")

        # Two-stage enhancement from the (optional) enhancer module.
        t_basic = time.time()
        step1 = basicEnhancing(no_bg)
        print(f"[TIMING] Basic Enhancement: {time.time() - t_basic:.4f}s")

        t_adv = time.time()
        final_enhanced = advancedEnhancing(step1)
        print(f"[TIMING] Advanced Enhancement: {time.time() - t_adv:.4f}s")

        # JPEG-encode in memory and stream it back to the client.
        t_enc = time.time()
        ok, jpeg = cv2.imencode('.jpg', final_enhanced, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
        if not ok:
            return jsonify({"error": "Encoding failed"}), 500
        img_io = io.BytesIO(jpeg.tobytes())
        img_io.seek(0)
        print(f"[TIMING] Save to file: {time.time() - t_enc:.4f}s")

        print(f"[TIMING] TOTAL ENHANCE REQUEST: {time.time() - overall_start:.4f}s")
        return send_file(
            img_io,
            mimetype='image/jpeg',
            as_attachment=False,
            download_name='enhanced.jpg'
        )
    except Exception as e:
        # Broad catch is deliberate: surface any pipeline failure as JSON.
        print(f"Enhancement Error: {e}")
        return jsonify({"error": str(e)}), 500
# Dev entry point; run behind a production WSGI server (e.g. gunicorn) in deployment.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=False)