Spaces:
Sleeping
Sleeping
File size: 4,911 Bytes
6f6e572 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
"""
AI Model Inference Module
Responsible for loading TensorFlow Keras models and performing AI/real predictions on facial images.
"""
import os
import numpy as np
from PIL import Image
from typing import Tuple, Optional
import logging
logger = logging.getLogger(__name__)
class FaceDetectorModel:
    """Wrapper around a Keras model that classifies face images as "AI" or "Human".

    The model is loaded eagerly in ``__init__``. Use ``predict`` to classify an
    image file, and ``warmup`` to run one throwaway inference ahead of time.
    """

    def __init__(self, model_path: str = "model/best_mobilenet_finetuned.keras"):
        """
        Initialize and load the AI detection model.

        Args:
            model_path: Path to the Keras model file.

        Raises:
            RuntimeError: If model loading fails.
        """
        self.model_path = model_path
        self.model = None
        # (height, width) of the image fed to the model. _load_model replaces
        # it with the model's own spatial input dims when those are concrete.
        self.input_size = (224, 224)
        # A model score strictly above this value is labeled "AI".
        self.threshold = 0.5
        try:
            self._load_model()
            logger.info(f"Model loaded successfully: {model_path}")
        except Exception as e:
            error_msg = f"Model loading failed: {e}"
            logger.error(error_msg)
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(error_msg) from e

    def _load_model(self):
        """Load the Keras model from ``self.model_path`` and adopt its input size.

        A standard load (with ``custom_objects``) is attempted first. If it
        fails for any reason, the load is retried with ``safe_mode=False``.

        Raises:
            RuntimeError: If TensorFlow cannot be imported or both load
                attempts fail.
        """
        try:
            import tensorflow as tf
            from tensorflow import keras

            # Map these layer names to Keras ops to help deserialization.
            # `keras.ops` only exists in newer Keras. Without this guard, its
            # absence raised here, before any load was even attempted.
            ops = getattr(keras, "ops", None)
            custom_objects = (
                {"Add": ops.add, "Multiply": ops.multiply} if ops is not None else {}
            )
            # NOTE(review): 'Add' and 'Multiply' are also built-in Keras layer
            # names; confirm these overrides are really needed for this model.
            try:
                self.model = tf.keras.models.load_model(
                    self.model_path,
                    custom_objects=custom_objects,
                    compile=False,
                )
            except Exception as e:
                # Retry with safe_mode DISABLED. SECURITY: this lets code
                # embedded in the model file (e.g. Lambda layers) be
                # deserialized and run, so only load trusted model files.
                logger.warning(f"Standard loading failed, trying safe mode: {e}")
                self.model = tf.keras.models.load_model(
                    self.model_path,
                    safe_mode=False,
                    compile=False,
                )

            # Take the spatial size from the model's input shape. This assumes
            # a channels-last (batch, height, width, channels) layout.
            input_shape = getattr(self.model, "input_shape", None)
            if input_shape is not None and len(input_shape) >= 3:
                height, width = input_shape[1], input_shape[2]
                # Dynamic dims come back as None, and multi-input models give
                # nested shapes. Keep the default size in both cases rather
                # than storing an unusable size.
                if isinstance(height, int) and isinstance(width, int):
                    self.input_size = (height, width)
        except Exception as e:
            raise RuntimeError(f"Cannot load model file {self.model_path}: {e}") from e

    def preprocess_image(self, image_path: str) -> Optional[np.ndarray]:
        """
        Load an image file and turn it into a model-ready batch of one.

        Args:
            image_path: Path to the image file.

        Returns:
            A float32 array of shape (1, height, width, 3) with pixel values
            scaled to [0, 1], or None if any step fails (the error is logged).
        """
        try:
            height, width = self.input_size
            # The context manager closes the underlying file. convert() returns
            # a new, fully loaded image, so it stays usable after the close.
            with Image.open(image_path) as src:
                img = src.convert("RGB")
            # PIL's resize takes (width, height); input_size is (height, width).
            img = img.resize((width, height))
            # NOTE(review): scaling to [0, 1] is assumed to match the
            # preprocessing used at training time; confirm for this model.
            img_array = np.array(img, dtype=np.float32) / 255.0
            # Add the batch dimension.
            return np.expand_dims(img_array, axis=0)
        except Exception as e:
            logger.error(f"Image preprocessing failed {image_path}: {e}")
            return None

    def predict(self, image_path: str) -> Tuple[str, float]:
        """
        Predict whether an image is AI-generated or a real human.

        Args:
            image_path: Path to the image file.

        Returns:
            Tuple[str, float]: (prediction label, confidence).
                - Prediction label: "AI" or "Human".
                - Confidence: the score of the predicted label. It is the raw
                  model score for "AI" and ``1 - score`` for "Human".
            If preprocessing or inference fails, ("AI", 0.5) is returned
            instead of raising (inference errors are logged).
        """
        img_array = self.preprocess_image(image_path)
        if img_array is None:
            # Preprocessing failed (already logged); return the fallback value.
            return "AI", 0.5

        try:
            # Assumes the first output unit of the first sample is the
            # "AI-generated" score. TODO confirm against the model head.
            prediction = self.model.predict(img_array, verbose=0)[0][0]
            # Scores > threshold are AI-generated; scores <= threshold are human.
            label = "AI" if prediction > self.threshold else "Human"
            confidence = float(prediction) if prediction > self.threshold else float(1 - prediction)
            return label, confidence
        except Exception as e:
            logger.error(f"Model prediction failed {image_path}: {e}")
            return "AI", 0.5

    def warmup(self):
        """
        Run one inference on random data to reduce first-prediction latency.

        Failures are logged as warnings and never raised.
        """
        try:
            height, width = self.input_size
            dummy_img = np.random.rand(1, height, width, 3).astype(np.float32)
            self.model.predict(dummy_img, verbose=0)
            logger.info("Model warmup complete")
        except Exception as e:
            logger.warning(f"Model warmup failed: {e}")
|