BugFreeAli committed on
Commit
251a307
Β·
verified Β·
1 Parent(s): 878d9aa

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +220 -116
main.py CHANGED
@@ -1,116 +1,220 @@
1
# Standard library
import io
import os

# Third-party
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from tensorflow.keras.models import load_model

# 1. INITIALIZE FASTAPI
app = FastAPI(title="Veritas AI Detector API")

# 2. CORS (CRITICAL: Allows your frontend to talk to this backend)
# NOTE(review): "*" origins is fine for development only — pin this to the
# real frontend URL before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins. For production, change to your frontend URL.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 3. GLOBAL MODEL LOADER
# Loaded once at startup (see load_ai_model) so every request reuses it.
MODEL_PATH = "with_flux_model.keras"
model = None
27
-
28
@app.on_event("startup")
async def load_ai_model():
    """Load the Keras model into the module-level `model` once at startup."""
    global model
    if not os.path.exists(MODEL_PATH):
        # Missing weights file: leave `model` as None; /predict reports 500.
        print(f"❌ CRITICAL ERROR: Model file '{MODEL_PATH}' not found.")
        return
    print(f"πŸ”„ Loading Model from {MODEL_PATH}...")
    # compile=False: inference only, loads faster and avoids optimizer state.
    model = load_model(MODEL_PATH, compile=False)
    print("βœ… Model Loaded Successfully!")
38
-
39
- # 4. PREPROCESSING FUNCTION (The "Perfect" Match)
40
def preprocess_image(image_bytes):
    """
    Exact replica of the training preprocessing.

    1. Open the image from raw bytes.
    2. Convert to RGB (fixes PNG palette/transparency issues).
    3. Resize to 224x224 with LANCZOS (matches training).
    4. Apply EfficientNetV2-specific preprocessing.

    Args:
        image_bytes: Raw encoded image bytes (JPEG/PNG/WebP).

    Returns:
        A (1, 224, 224, 3) preprocessed batch ready for model.predict.

    Raises:
        ValueError: If the bytes cannot be decoded/processed as an image.
    """
    try:
        img = Image.open(io.BytesIO(image_bytes))

        # Force RGB so grayscale/RGBA inputs get a consistent 3 channels.
        img = img.convert('RGB')

        # LANCZOS resampling is what training used.
        img = img.resize((224, 224), Image.Resampling.LANCZOS)

        # (224, 224, 3) -> (1, 224, 224, 3) batch.
        img_array = np.expand_dims(np.array(img), axis=0)

        return preprocess_input(img_array)

    except Exception as e:
        # BUG FIX: chain the original exception (`from e`) so the real
        # decode failure is not lost from the traceback.
        raise ValueError(f"Image processing failed: {str(e)}") from e
69
-
70
# 5. THE PREDICTION ENDPOINT
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Classify an uploaded image as AI-generated or Real.

    Returns a JSON payload with the winning class, confidence percentage,
    and both class probabilities.

    Raises:
        HTTPException 400: Unsupported content type or undecodable image.
        HTTPException 500: Model not loaded or unexpected inference failure.
    """
    global model

    # Validation
    if not model:
        raise HTTPException(status_code=500, detail="Model not loaded")

    if file.content_type not in ["image/jpeg", "image/png", "image/jpg", "image/webp"]:
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload JPG or PNG.")

    contents = await file.read()

    try:
        # preprocess_image raises ValueError for undecodable input — that is
        # a client error, so surface it as 400, not a generic failure.
        processed_image = preprocess_image(contents)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        prediction = model.predict(processed_image)

        # Classes (alphabetical order used by Keras): 0=AI, 1=Real.
        ai_score = float(prediction[0][0])
        real_score = float(prediction[0][1])

        result_class = "AI" if ai_score > real_score else "Real"
        confidence = max(ai_score, real_score) * 100

        return {
            "prediction": result_class,
            "confidence_percentage": round(confidence, 2),
            "probabilities": {
                "ai": round(ai_score, 4),
                "real": round(real_score, 4)
            },
            "status": "success"
        }
    except Exception as e:
        # BUG FIX: inference failures used to come back as HTTP 200 with
        # {"status": "error"}, hiding server faults from clients/monitoring.
        raise HTTPException(status_code=500, detail=str(e)) from e
112
-
113
# Health Check
@app.get("/")
def home():
    """Liveness probe: confirms the API process is up and responding."""
    return {"status": "online", "message": "Veritas AI Detector is running."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import base64
import io
import os
import tempfile

# Third-party
import cv2
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
from tensorflow.keras.models import load_model

# ==========================================
# 1. INITIAL SETUP
# ==========================================
app = FastAPI(title="Veritas Forensic Engine")

# NOTE(review): "*" origins is development-only — pin to the frontend URL
# before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Loaded once at startup (see load_ai_model) and shared by all endpoints.
MODEL_PATH = "with_flux_model.keras"
model = None

# Last convolutional layer tapped for Grad-CAM.
# For EfficientNetV2B0 it is usually 'top_activation'.
LAST_CONV_LAYER_NAME = "top_activation"
33
+
34
@app.on_event("startup")
async def load_ai_model():
    """Load the forensic model into the module-level `model` at startup."""
    global model
    if not os.path.exists(MODEL_PATH):
        # No weights file: endpoints will answer 500 until one is provided.
        print(f"❌ CRITICAL: {MODEL_PATH} not found.")
        return
    print("πŸ”„ Loading Forensic Model...")
    # compile=False: inference-only load, no optimizer state required.
    model = load_model(MODEL_PATH, compile=False)
    print("βœ… Model Loaded. Ready for Image & Video Analysis.")
43
+
44
+ # ==========================================
45
+ # 2. CORE PROCESSING (Sanitization)
46
+ # ==========================================
47
def preprocess_image_array(img_array):
    """
    Standardize a raw RGB numpy array for the model.

    Resizes to 224x224 with TensorFlow's lanczos3 kernel (chosen to match
    training), prepends a batch dimension, and applies the EfficientNetV2
    preprocessing. Returns a (1, 224, 224, 3) batch tensor.
    """
    resized = tf.image.resize(img_array, (224, 224), method='lanczos3')
    # (224, 224, 3) -> (1, 224, 224, 3)
    batched = resized[tf.newaxis, ...]
    return preprocess_input(batched)
56
+
57
def make_gradcam_heatmap(img_array, pred_index=None):
    """
    Generate the Grad-CAM 'X-Ray' heatmap for a preprocessed input batch.

    Args:
        img_array: Preprocessed (1, H, W, 3) batch tensor (see
            preprocess_image_array).
        pred_index: Class index to explain; defaults to the argmax class.

    Returns:
        A 2D numpy array in [0, 1] — spatial importance of the chosen class
        at the resolution of the last conv feature map.
    """
    # Model mapping the input image to both the last conv layer's
    # activations and the final predictions.
    grad_model = tf.keras.models.Model(
        [model.inputs],
        [model.get_layer(LAST_CONV_LAYER_NAME).output, model.output]
    )

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # Gradient of the class score w.r.t. the conv feature map.
    grads = tape.gradient(class_channel, last_conv_layer_output)

    # Average gradients spatially -> one weight per channel.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weight the feature map channels by class importance.
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # Normalize to [0, 1].
    # BUG FIX: dividing by reduce_max without a guard produced an all-NaN
    # heatmap whenever the ReLU'd map was all zeros (no positive gradients).
    heatmap = tf.maximum(heatmap, 0)
    max_val = tf.math.reduce_max(heatmap)
    if max_val > 0:
        heatmap = heatmap / max_val
    return heatmap.numpy()
88
+
89
def overlay_heatmap(original_img_pil, heatmap):
    """
    Merge the Grad-CAM heatmap with the original image for display.

    Args:
        original_img_pil: The original PIL image (RGB).
        heatmap: 2D numpy array in [0, 1] from make_gradcam_heatmap.

    Returns:
        Base64-encoded JPEG string of the blended image for the frontend.

    Raises:
        ValueError: If JPEG encoding fails.
    """
    # BUG FIX: PIL yields RGB but cv2.applyColorMap produces BGR and
    # cv2.imencode expects BGR — blending them directly swapped the photo's
    # red/blue channels in the output JPEG. Convert the photo to BGR first.
    img = cv2.cvtColor(np.array(original_img_pil), cv2.COLOR_RGB2BGR)

    # Rescale heatmap to 0-255 and colorize with the jet colormap.
    heatmap = np.uint8(255 * heatmap)
    jet = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # Resize heatmap to match the original image size.
    jet = cv2.resize(jet, (img.shape[1], img.shape[0]))

    # Superimpose: 0.4 heatmap intensity, 0.6 original image intensity.
    superimposed_img = jet * 0.4 + img * 0.6
    superimposed_img = np.clip(superimposed_img, 0, 255).astype("uint8")

    # Encode back to Base64 to send to the frontend.
    is_success, buffer = cv2.imencode(".jpg", superimposed_img)
    if not is_success:
        # BUG FIX: the success flag was silently ignored; a failed encode
        # would have produced a garbage/empty base64 payload.
        raise ValueError("Failed to encode heatmap overlay as JPEG")
    return base64.b64encode(buffer).decode("utf-8")
113
+
114
+ # ==========================================
115
+ # 3. ENDPOINT: IMAGE FORENSICS
116
+ # ==========================================
117
@app.post("/api/analyze-image")
async def analyze_image(file: UploadFile = File(...)):
    """
    Full forensic pass on one image: AI/Real classification + Grad-CAM proof.

    Returns prediction label, confidence percentage, class probabilities,
    and a base64-encoded heatmap overlay. Processing failures are returned
    as a 200 response with an "error" key (existing frontend contract).

    Raises:
        HTTPException 400: Unsupported upload content type.
        HTTPException 500: Model not yet loaded.
    """
    if not model:
        raise HTTPException(500, "Model loading...")

    # BUG FIX: content-type validation existed in the previous revision but
    # was dropped here — without it any upload (PDFs, executables) reached
    # the PIL decoder.
    if file.content_type not in ("image/jpeg", "image/jpg", "image/png", "image/webp"):
        raise HTTPException(400, "Invalid file type. Please upload JPG, PNG or WEBP.")

    try:
        # 1. Read & sanitize (force RGB so alpha/palette images behave).
        contents = await file.read()
        img = Image.open(io.BytesIO(contents)).convert('RGB')

        # 2. Prepare for the model.
        processed_tensor = preprocess_image_array(np.array(img))

        # 3. Predict. Class order from training (alphabetical): 0=AI, 1=Real.
        preds = model.predict(processed_tensor)
        ai_score = float(preds[0][0])
        real_score = float(preds[0][1])

        confidence = max(ai_score, real_score) * 100
        label = "AI" if ai_score > real_score else "Real"

        # 4. Heatmap is generated for every request so the UI can toggle it.
        heatmap = make_gradcam_heatmap(processed_tensor)
        heatmap_b64 = overlay_heatmap(img, heatmap)

        return {
            "type": "image",
            "prediction": label,
            "confidence": round(confidence, 2),
            "heatmap_base64": heatmap_b64,  # The visual proof
            "probabilities": {"ai": ai_score, "real": real_score}
        }

    except Exception as e:
        return {"error": str(e)}
152
+
153
+ # ==========================================
154
+ # 4. ENDPOINT: VIDEO SENTINEL
155
+ # ==========================================
156
@app.post("/api/analyze-video")
async def analyze_video(file: UploadFile = File(...)):
    """
    Deepfake scan for an uploaded video.

    Samples up to 10 evenly spaced frames, classifies each one, and returns
    a per-frame timeline plus an overall verdict ("DEEPFAKE DETECTED" when
    more than 40% of sampled frames score as fake).

    Raises:
        HTTPException 400: No frames could be decoded from the upload.
        HTTPException 500: Model not yet loaded.
    """
    if not model:
        raise HTTPException(500, "Model loading...")

    temp_path = None
    cap = None
    try:
        # 1. Save video temporarily — OpenCV needs a real file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_vid:
            temp_vid.write(await file.read())
            temp_path = temp_vid.name

        # 2. Process video.
        cap = cv2.VideoCapture(temp_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        # Extract up to 10 keyframes evenly spaced across the video.
        frames_to_analyze = 10
        step = max(1, frame_count // frames_to_analyze)

        timeline_results = []
        fake_frame_count = 0

        for i in range(0, frame_count, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes BGR; the model expects RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            processed_tensor = preprocess_image_array(frame_rgb)
            preds = model.predict(processed_tensor)
            ai_score = float(preds[0][0])

            timestamp = round(i / fps, 2) if fps > 0 else 0
            timeline_results.append({
                "timestamp": timestamp,
                "ai_score": ai_score,
                "status": "FAKE" if ai_score > 0.5 else "REAL"
            })
            if ai_score > 0.5:
                fake_frame_count += 1

            if len(timeline_results) >= frames_to_analyze:
                break

        # BUG FIX: an unreadable/empty video produced ZeroDivisionError
        # (caught and surfaced as an unhelpful {"error": "division by zero"}).
        if not timeline_results:
            raise HTTPException(400, "Could not decode any frames from the video.")

        # 3. Final video verdict.
        overall_fake_percent = (fake_frame_count / len(timeline_results)) * 100
        final_verdict = "DEEPFAKE DETECTED" if overall_fake_percent > 40 else "AUTHENTIC VIDEO"

        return {
            "type": "video",
            "prediction": final_verdict,
            "fake_percentage": round(overall_fake_percent, 2),
            "timeline": timeline_results  # List of frame analysis
        }

    except HTTPException:
        raise
    except Exception as e:
        return {"error": str(e)}
    finally:
        # BUG FIX: cleanup ran only on the success path — any exception
        # leaked the temp file and left the capture handle open.
        if cap is not None:
            cap.release()
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)