# Fingerprint comparison & enhancement Flask service (HuggingFace Spaces app)
import os
import time

# --- Performance Logging Start: record wall time so slow cold starts can be
# diagnosed from the startup log.
start_import = time.time()

from flask import Flask, request, jsonify, Response
import cv2
import numpy as np
from os.path import join, dirname

# ── PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models

from rembg import remove, new_session
from PIL import Image
import io
from flask import send_file

# Set model path for rembg: keep the u2net weights next to this file rather
# than the user home directory (useful on containerized / read-only hosts).
os.environ['U2NET_HOME'] = join(dirname(__file__), '.u2net')

# Import your custom enhancement functions.
# NOTE(review): on ImportError only a warning is printed; the enhance endpoint
# would then fail at call time with a NameError (caught and returned as 500) —
# confirm this best-effort import is intentional.
try:
    from enhancer import basicEnhancing, advancedEnhancing
except ImportError:
    print("Warning: 'enhancer' module not found.")

print(f"Imports loaded in: {time.time() - start_import:.2f}s")

app = Flask(__name__)
# MODEL ARCHITECTURE (ResNet50 + ArcFace )
class FingerprintCNN(nn.Module):
    """ResNet50 backbone adapted to 1-channel fingerprint images.

    Produces an L2-normalized ``embedding_dim``-dimensional vector per image.
    Attribute names deliberately mirror the trained checkpoint layout
    (conv1, bn1, layer1..layer4, avgpool, embedding.*) so existing
    state_dicts keep loading unchanged.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        resnet = models.resnet50(weights=None)
        # Stem conv is rebuilt for a single (grayscale) input channel.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool
        # Projection head: 2048 -> 512 -> embedding_dim.
        self.embedding = nn.Sequential(
            nn.Linear(2048, 512), nn.BatchNorm1d(512),
            nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(512, embedding_dim), nn.BatchNorm1d(embedding_dim),
        )

    def forward(self, x):
        """Map a [B, 1, 224, 224] batch to unit-norm [B, embedding_dim] vectors."""
        stages = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for stage in stages:
            x = stage(x)
        features = torch.flatten(x, 1)
        return F.normalize(self.embedding(features), p=2, dim=1)
class SiameseNetwork(nn.Module):
    """Siamese wrapper sharing a single FingerprintCNN backbone.

    Pairs are compared by embedding each image independently via
    ``forward_one`` and scoring the two vectors outside the network.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        self.backbone = FingerprintCNN(embedding_dim)

    def forward_one(self, x):
        """Embed one batch of images -> L2-normalized vectors."""
        return self.backbone(x)

    def forward(self, x):
        # Fix: without a forward() the base nn.Module raises when the model
        # is called directly (model(x)); delegate to the shared backbone.
        # Backward-compatible: forward_one and the state_dict are unchanged.
        return self.backbone(x)
# CONFIGURATION
DIRNAME = dirname(__file__)
MODEL_PATH = join(DIRNAME, 'model', 'resnet50_arcface.pth')  # Your trained model
IMAGE_SIZE = (224, 224)  # model input resolution
EMBEDDING_DIM = 128  # dimensionality of the fingerprint embedding

# Thresholds
# Update ZERO_FA_THRESHOLD after running find_threshold.py on your data
# EER_THRESHOLD = 0.4505  # Balanced: 97% accuracy, 3% EER
# ZERO_FA_THRESHOLD = 0.5951  # Strict: 0 false accepts 84 % accuracy
# MATCH_THRESHOLD = ZERO_FA_THRESHOLD  # Default: strict (zero false accepts)
MATCH_THRESHOLD = 0.75  # scores >= this value count as a match

# Image transform: grayscale [H, W] uint8 -> float tensor; Normalize(0.5, 0.5)
# maps pixel values into [-1, 1] — presumably matching training; TODO confirm.
img_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# GLOBAL STATE
siamese_model = None  # set by load_and_warmup_model(); None => model unavailable
# rembg session created once at import so the u2net weights load a single time
rembg_session = new_session("u2net")
# MODEL LOADING & WARMUP
def load_and_warmup_model():
    """Load the Siamese checkpoint into the global model and warm it up.

    Runs one dummy forward pass so the first real request does not pay the
    lazy-initialization cost. On any failure the global ``siamese_model`` is
    reset to None; request handlers check for that and return 500.
    """
    global siamese_model
    t_start = time.time()
    try:
        print(f"Loading model from: {MODEL_PATH}")
        siamese_model = SiameseNetwork(EMBEDDING_DIM)
        # Fix: weights_only=True — the file is consumed as a plain state_dict
        # below, so tensor-only deserialization suffices and avoids arbitrary
        # code execution from a tampered pickle (see torch.load docs).
        state_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=True)
        siamese_model.load_state_dict(state_dict)
        siamese_model.eval()
        # Warmup: run dummy inference so first request isn't slow
        print("Warming up model...")
        dummy = torch.zeros(1, 1, 224, 224)
        with torch.no_grad():
            siamese_model.forward_one(dummy)
        print(f"Model loaded and warmed up in: {time.time() - t_start:.2f}s")
    except Exception as e:
        # Startup boundary: degrade to "model unavailable" instead of crashing;
        # compare route returns 500 while siamese_model is None.
        print(f"Critical Error: Could not load model: {e}")
        siamese_model = None


# Load model on startup (import-time side effect; each worker process runs this).
load_and_warmup_model()
# PREPROCESSING
def preprocess_fingerprint(image_bytes, image_type):
    """Turn raw encoded image bytes into a model-ready tensor.

    Pipeline (kept identical to the Kaggle training pipeline):
    grayscale decode -> CLAHE (clipLimit=3.0, 8x8 tiles) -> resize 224x224
    -> ToTensor + Normalize(0.5, 0.5).

    Args:
        image_bytes: encoded image bytes (any format cv2.imdecode accepts).
        image_type: 'contactless' or 'contactbased'. Currently unused — the
            contactless rotate-90-CCW + mirror alignment used for PolyU test
            data is disabled; the parameter is kept for interface stability.

    Returns:
        torch.FloatTensor of shape [1, 1, 224, 224], or None when the bytes
        are empty or cannot be decoded.
    """
    if not image_bytes:
        return None
    buffer = np.frombuffer(image_bytes, np.uint8)
    gray = cv2.imdecode(buffer, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        return None
    # Local-contrast boost; same CLAHE settings as during training.
    equalized = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(gray)
    model_input = cv2.resize(equalized, IMAGE_SIZE)
    # [H, W] uint8 -> [1, 1, 224, 224] normalized float tensor.
    return img_transform(model_input).unsqueeze(0)
def compute_match_score(emb1, emb2):
    """Score the similarity of two embedding vectors.

    Combines two views of similarity with equal weight:
      * ``exp(-||emb1 - emb2||)``   — Euclidean distance mapped into (0, 1]
      * ``(cos(emb1, emb2) + 1)/2`` — cosine similarity mapped into [0, 1]

    Args:
        emb1, emb2: 1-D numpy arrays of equal length (the 128-dim embeddings).

    Returns:
        Tuple ``(final_score, cosine_similarity, euclidean_distance)`` —
        note: a 3-tuple, not a single value (the previous docstring was
        wrong). ``final_score`` is in (0, 1]; higher means more similar.
    """
    distance = float(np.linalg.norm(emb1 - emb2))
    # 1e-8 in the denominator guards against division by zero for
    # degenerate all-zero embeddings.
    cos_sim = float(np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-8))
    dist_score = float(np.exp(-distance))
    cos_score = float((cos_sim + 1) / 2)
    final_score = 0.5 * dist_score + 0.5 * cos_score
    return final_score, cos_sim, distance
# BACKGROUND REMOVAL for enhancement
def remove_bg_from_cv2(cv2_img):
    """Converts CV2 image to PIL, removes background, returns CV2 (BGR).

    Uses the module-level rembg u2net session so weights load only once.
    NOTE(review): the RGBA2BGR conversion below assumes rembg returns an
    RGBA image — confirm for the installed rembg version.
    """
    img_rgb = cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB)
    pil_img = Image.fromarray(img_rgb)
    output_pil = remove(pil_img, session=rembg_session)
    output_cv2 = cv2.cvtColor(np.array(output_pil), cv2.COLOR_RGBA2BGR)
    return output_cv2
# ROUTES
def health_check():
    """Status endpoint: reports liveness, model availability and threshold.

    NOTE(review): no @app.route decorator is visible in this chunk — confirm
    this handler is registered elsewhere (e.g. via app.add_url_rule).
    """
    payload = {
        "status": "online",
        "message": "Fingerprint Comparison API is running",
        "model_loaded": siamese_model is not None,
        "model_type": "ResNet50-ArcFace (PyTorch)",
        "threshold": MATCH_THRESHOLD,
    }
    return jsonify(payload), 200
def compare_fingerprints():
    """Compare two uploaded fingerprint images; return a match verdict as JSON.

    Expects multipart/form-data with:
        img_1, img_2 (files)  — fingerprint images
        type_1, type_2 (text) — capture types; default 'contactless' / 'contactbased'

    Returns 200 with score/match fields, 400 on missing or undecodable
    images, 500 when the model failed to load at startup.

    NOTE(review): no @app.route decorator is visible in this chunk — confirm
    this handler is registered elsewhere.
    """
    request_start = time.time()
    # Expecting: img_1 (file), img_2 (file), type_1 (text), type_2 (text)
    if 'img_1' not in request.files or 'img_2' not in request.files:
        return jsonify({"error": "Missing image files"}), 400
    if siamese_model is None:
        return jsonify({"error": "Model not loaded"}), 500
    # 1. Read Files
    t_read = time.time()
    img_1_file = request.files['img_1'].read()
    img_2_file = request.files['img_2'].read()
    # Defaults
    type_1 = request.form.get('type_1', 'contactless')
    type_2 = request.form.get('type_2', 'contactbased')
    print(f"File Read Time: {time.time() - t_read:.4f}s")
    # 2. Preprocess (each image timed separately for profiling)
    # Timer for Image 1
    t_p1 = time.time()
    processed_1 = preprocess_fingerprint(img_1_file, type_1)
    time_p1 = time.time() - t_p1
    print(f"Preprocessing Image 1 ({type_1}): {time_p1:.4f}s")
    # Timer for Image 2
    t_p2 = time.time()
    processed_2 = preprocess_fingerprint(img_2_file, type_2)
    time_p2 = time.time() - t_p2
    print(f"Preprocessing Image 2 ({type_2}): {time_p2:.4f}s")
    if processed_1 is None or processed_2 is None:
        return jsonify({"error": "Failed to process images"}), 400
    # 3. Perform Inference — get embeddings and compute score
    t_inf = time.time()
    with torch.no_grad():
        emb1 = siamese_model.forward_one(processed_1).numpy().flatten()
        emb2 = siamese_model.forward_one(processed_2).numpy().flatten()
        score, cosine_sim, euclidean_dist = compute_match_score(emb1, emb2)
    time_inf = time.time() - t_inf
    print(f"Inference Time: {time_inf:.4f}s")
    # Match percentage mapped to [0, 100]
    match_percentage = round(score * 100, 2)
    total_time = time.time() - request_start
    print(f"Total Request Time: {total_time:.4f}s")
    # 4. Return JSON response; is_match applies the strict global threshold
    return jsonify({
        "score": round(score, 6),
        "match_percentage": match_percentage,
        "is_match": score >= MATCH_THRESHOLD,
        "metadata": {
            "img_1_type": type_1,
            "img_2_type": type_2,
            "threshold_used": MATCH_THRESHOLD,
            "cosine_similarity": round(cosine_sim, 6),
            "euclidean_distance": round(euclidean_dist, 6),
        }
    })
def enhance_image_api():
    """Remove the background from an uploaded image, enhance it, return JPEG.

    Expects multipart/form-data with an 'image' file. Returns the enhanced
    image as a JPEG response, 400 on missing/undecodable input, 500 on any
    processing error.

    NOTE(review): no @app.route decorator is visible in this chunk — confirm
    this handler is registered elsewhere.
    """
    overall_start = time.time()
    if 'image' not in request.files:
        return jsonify({"error": "No image file provided."}), 400
    file = request.files['image']
    try:
        # Decode Image
        t_start = time.time()
        file_bytes = np.frombuffer(file.read(), np.uint8)
        img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        if img is None:
            return jsonify({"error": "Failed to decode image"}), 400
        print(f"[TIMING] Image Decode: {time.time() - t_start:.4f}s")
        # Background Removal (rembg u2net, shared session)
        t_bg = time.time()
        no_bg = remove_bg_from_cv2(img)
        print(f"[TIMING] Background Removal: {time.time() - t_bg:.4f}s")
        # Basic Enhancement
        # NOTE(review): basicEnhancing/advancedEnhancing are NameErrors when
        # the optional 'enhancer' import failed at startup; the except below
        # converts that into a 500 response.
        t_basic = time.time()
        step1 = basicEnhancing(no_bg)
        print(f"[TIMING] Basic Enhancement: {time.time() - t_basic:.4f}s")
        # Advanced Enhancement
        t_adv = time.time()
        final_enhanced = advancedEnhancing(step1)
        print(f"[TIMING] Advanced Enhancement: {time.time() - t_adv:.4f}s")
        # Encode to JPEG (quality 95), entirely in memory — no temp files
        t_enc = time.time()
        success, encoded_image = cv2.imencode('.jpg', final_enhanced, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
        if not success:
            return jsonify({"error": "Encoding failed"}), 500
        img_io = io.BytesIO(encoded_image.tobytes())
        img_io.seek(0)
        print(f"[TIMING] Save to file: {time.time() - t_enc:.4f}s")
        total_time = time.time() - overall_start
        print(f"[TIMING] TOTAL ENHANCE REQUEST: {total_time:.4f}s")
        return send_file(
            img_io,
            mimetype='image/jpeg',
            as_attachment=False,
            download_name='enhanced.jpg'
        )
    except Exception as e:
        # Broad catch at the request boundary; error text is surfaced to the
        # client in the JSON body.
        print(f"Enhancement Error: {e}")
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    # Development entry point: Flask's built-in server on all interfaces.
    # For production, run under a WSGI server (e.g. gunicorn) instead.
    app.run(host='0.0.0.0', port=5000, debug=False)