File size: 16,898 Bytes
2a08a48
1ccc2d4
7d22b09
1ccc2d4
 
2a08a48
28f65c2
4215391
7d22b09
0c0bbf3
1ccc2d4
28f65c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a08a48
1ccc2d4
 
 
 
 
28f65c2
0c0bbf3
 
 
 
 
28f65c2
b5bf3aa
 
 
 
 
 
4215391
0c0bbf3
2a08a48
1ccc2d4
 
 
0c0bbf3
4215391
097628e
f90165a
b6f1be7
28f65c2
 
 
 
 
4215391
097628e
4215391
097628e
 
 
b6f1be7
 
4215391
097628e
 
2a08a48
097628e
 
 
 
b6f1be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0ce7d9
b6f1be7
a0ce7d9
f90165a
4215391
 
 
 
 
f90165a
2a08a48
f90165a
2a08a48
4215391
b6f1be7
4215391
28f65c2
 
 
 
 
 
2a08a48
4215391
 
 
b6f1be7
4215391
 
 
 
 
 
 
b6f1be7
 
4215391
b6f1be7
4215391
b6f1be7
b5bf3aa
4215391
28f65c2
b6f1be7
4215391
28f65c2
 
 
 
 
4215391
 
 
b6f1be7
4215391
 
 
2a08a48
b6f1be7
4215391
 
b6f1be7
 
f90165a
0c0bbf3
4215391
 
 
b6f1be7
4215391
 
 
 
b6f1be7
4215391
 
 
 
 
28f65c2
b6f1be7
4215391
 
 
 
b6f1be7
4215391
 
2a08a48
4215391
 
 
b6f1be7
4215391
 
901e54d
1ccc2d4
28f65c2
1ccc2d4
4215391
 
28f65c2
b6f1be7
4215391
 
b6f1be7
4215391
1ccc2d4
f90165a
1ccc2d4
7d22b09
28f65c2
4215391
 
 
2a08a48
4215391
 
 
 
b6f1be7
 
f90165a
b6f1be7
 
4215391
 
 
b6f1be7
68f9c28
28f65c2
68f9c28
28f65c2
 
6f206ca
 
28f65c2
 
 
4215391
28f65c2
b6f1be7
28f65c2
b6f1be7
 
 
28f65c2
b6f1be7
 
 
28f65c2
b6f1be7
 
4215391
28f65c2
 
b6f1be7
 
 
28f65c2
6f206ca
b6f1be7
4215391
 
28f65c2
 
 
4215391
 
28f65c2
4215391
 
 
 
28f65c2
b6f1be7
 
 
4215391
28f65c2
b6f1be7
 
28f65c2
 
b6f1be7
 
28f65c2
4215391
28f65c2
b6f1be7
 
4215391
 
 
28f65c2
b6f1be7
 
4215391
28f65c2
b6f1be7
28f65c2
b6f1be7
 
2a08a48
28f65c2
b6f1be7
 
 
 
 
28f65c2
b6f1be7
 
 
28f65c2
4215391
097628e
a5bc248
28f65c2
a5bc248
28f65c2
4215391
28f65c2
4215391
28f65c2
 
b6f1be7
28f65c2
 
 
f90165a
4215391
28f65c2
b6f1be7
4215391
 
 
 
 
28f65c2
4215391
 
28f65c2
4215391
28f65c2
 
4215391
 
 
 
2a08a48
28f65c2
 
 
 
 
 
 
4215391
 
28f65c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a08a48
 
b6f1be7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
# --- Standard Libraries ---
import logging
import atexit
import tempfile
import os
import hashlib
import shutil
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional
from pathlib import Path

# ==========================================
# πŸš€ FAST STARTUP: LOCAL WEIGHTS SETUP
# ==========================================
def setup_local_weights():
    """Install a locally bundled Facenet512 weight file for DeepFace.

    Points DEEPFACE_HOME at the current working directory, ensures the
    ``.deepface/weights`` folder exists, and copies the bundled weight
    file into it when present, so DeepFace skips its network download.
    """
    cwd = os.getcwd()
    os.environ["DEEPFACE_HOME"] = cwd

    weights_dir = os.path.join(cwd, ".deepface", "weights")
    os.makedirs(weights_dir, exist_ok=True)

    weight_file = "facenet512_weights.h5"
    installed_path = os.path.join(weights_dir, weight_file)

    if os.path.exists(installed_path):
        print("βœ… Weights already installed.")
    elif os.path.exists(weight_file):
        print(f"πŸ“¦ Found local weights: {weight_file}. Installing...")
        # Copy (not move) so the original stays visible in the file listing.
        shutil.copy(weight_file, installed_path)
    else:
        print("⚠️ Local weights not found. DeepFace might download them.")

# RUN THIS IMMEDIATELY BEFORE IMPORTS
setup_local_weights()

# --- Computer Vision & UI Libraries ---
import cv2
import numpy as np
import gradio as gr
from ultralytics import YOLO

# --- Face Recognition Libraries ---
# Both libraries are optional at runtime: each guard sets an availability
# flag that FaceDatabase checks before enabling recognition/storage.
try:
    from deepface import DeepFace
    DEEPFACE_AVAILABLE = True
except ImportError:
    DEEPFACE_AVAILABLE = False
    logging.warning("⚠️ DeepFace not installed - Recognition disabled.")

try:
    import chromadb
    CHROMADB_AVAILABLE = True
except ImportError:
    CHROMADB_AVAILABLE = False
    logging.warning("⚠️ ChromaDB not installed - Database features disabled.")

# --- Configure Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ====================================================
# 1. CONFIGURATION & UTILITIES
# ====================================================

# --- TUNED PARAMETERS ---
DETECTION_CONFIDENCE = 0.4   # Moderate strictness to reduce false boxes
RECOGNITION_THRESHOLD = 0.30 # STRICT: Prevents Identity Confusion
TARGET_MOSAIC_GRID = 10      # Pixelation Grid Size
MIN_PIXEL_SIZE = 12          # Min block size (prevents weak blur on small faces)
COVERAGE_SCALE = 1.2         # 120% padding (Captures hair/chin)

TEMP_FILES = []

def cleanup_temp_files():
    for f in TEMP_FILES:
        try:
            if os.path.exists(f): os.remove(f)
        except Exception: pass

atexit.register(cleanup_temp_files)

def create_temp_file(suffix=".mp4") -> str:
    path = tempfile.mktemp(suffix=suffix)
    TEMP_FILES.append(path)
    return path

def get_padded_coords(image, x, y, w, h, scale=COVERAGE_SCALE):
    """
    UNIFIED COORDINATE SYSTEM:
    Calculates the padded coordinates once so Blur and Box match perfectly.
    """
    h_img, w_img = image.shape[:2]
    pad_w = int(w * (scale - 1.0) / 2)
    pad_h = int(h * (scale - 1.0) / 2)
    
    new_x = max(0, x - pad_w)
    new_y = max(0, y - pad_h)
    new_w = min(w_img - new_x, w + (2 * pad_w))
    new_h = min(h_img - new_y, h + (2 * pad_h))
    
    return new_x, new_y, new_w, new_h

# ====================================================
# 2. THE DATABASE LAYER
# ====================================================
class FaceDatabase:
    """Vector database of known-face embeddings backed by ChromaDB.

    Face images are laid out on disk as ``known_faces/<id>_<name>/*.jpg``.
    Each image is embedded once with DeepFace/Facenet512 and stored under
    the MD5 hash of the file contents, so re-runs skip already-indexed files.
    """

    def __init__(self, db_path="./chroma_db", faces_dir="known_faces"):
        self.faces_dir = Path(faces_dir)
        self.collection = None
        self.is_active = False

        # Both DeepFace (embeddings) and ChromaDB (storage) are required.
        if not DEEPFACE_AVAILABLE or not CHROMADB_AVAILABLE:
            return

        try:
            self.client = chromadb.PersistentClient(path=db_path)
            # Cosine distance matches the RECOGNITION_THRESHOLD semantics.
            self.collection = self.client.get_or_create_collection(name="face_embeddings", metadata={"hnsw:space": "cosine"})
            self.is_active = True

            if self.faces_dir.exists():
                self._scan_and_index()
            else:
                self.faces_dir.mkdir(parents=True, exist_ok=True)

        except Exception as e:
            logger.error(f"❌ DB Init Error: {e}")

    def _get_hash(self, img_path: Path) -> str:
        """MD5 of the file contents — a stable, dedup-friendly record ID."""
        with open(img_path, 'rb') as f: return hashlib.md5(f.read()).hexdigest()

    def _scan_and_index(self):
        """Walk 'known_faces/<id>_<name>' folders and embed any new images."""
        logger.info("πŸ”„ Scanning 'known_faces' folder...")
        for person_dir in self.faces_dir.iterdir():
            if not person_dir.is_dir(): continue

            # Folder naming convention: "<id>_<display name>".
            parts = person_dir.name.split('_', 1)
            p_id = parts[0] if len(parts) > 1 else "000"
            p_name = parts[1].replace('_', ' ') if len(parts) > 1 else person_dir.name

            images = list(person_dir.glob("*.*"))
            for img_path in images:
                if img_path.suffix.lower() not in ['.jpg', '.png', '.webp', '.jpeg']: continue
                try:
                    img_hash = self._get_hash(img_path)
                    # Skip if already indexed
                    if self.collection.get(ids=[img_hash])['ids']: continue

                    embedding_objs = DeepFace.represent(
                        img_path=str(img_path),
                        model_name="Facenet512",
                        enforce_detection=False
                    )
                    if embedding_objs:
                        self.collection.add(
                            ids=[img_hash],
                            embeddings=[embedding_objs[0]["embedding"]],
                            metadatas=[{"id": p_id, "name": p_name, "file": img_path.name}]
                        )
                        logger.info(f"βœ… Indexed: {p_name}")
                except Exception as e:
                    logger.error(f"⚠️ Skip {img_path.name}: {e}")

    def recognize(self, face_img: np.ndarray) -> Dict[str, Any]:
        """Match an RGB face crop against the database.

        Returns a dict with keys ``match``, ``name``, ``id`` and ``color``
        (tuple consumed by the drawing layer). Falls back to the "Unknown"
        result on any failure rather than raising.
        """
        default = {"match": False, "name": "Unknown", "id": "Unknown", "color": (255, 0, 0)} # Red
        if not self.is_active or self.collection.count() == 0: return default

        try:
            # FIX: unique temp path — the old fixed "temp_query.jpg" raced
            # between concurrent requests. Cleanup moved to `finally` so the
            # file no longer leaks when DeepFace.represent raises.
            fd, temp_path = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            try:
                cv2.imwrite(temp_path, cv2.cvtColor(face_img, cv2.COLOR_RGB2BGR))
                embedding_objs = DeepFace.represent(img_path=temp_path, model_name="Facenet512", enforce_detection=False)
            finally:
                if os.path.exists(temp_path): os.remove(temp_path)

            if not embedding_objs: return default

            results = self.collection.query(query_embeddings=[embedding_objs[0]["embedding"]], n_results=1)
            if not results['ids'][0]: return default

            distance = results['distances'][0][0]
            metadata = results['metadatas'][0][0]

            # STRICT THRESHOLD APPLIED HERE
            if distance < RECOGNITION_THRESHOLD:
                return {
                    "match": True,
                    "name": metadata['name'],
                    "id": metadata['id'],
                    "color": (0, 255, 0) # Green
                }
            return default
        except Exception as e:
            # FIX: was a silent swallow — log so recognition failures are visible.
            logger.warning(f"Recognition failed: {e}")
            return default

    def get_stats(self):
        """Human-readable status line for the UI header."""
        return f"βœ… Active | {self.collection.count()} Faces" if (self.is_active and self.collection) else "❌ Offline"

FACE_DB = FaceDatabase()

# ====================================================
# 3. DETECTOR & PROCESSING LOGIC
# ====================================================
class Detector:
    """Thin wrapper around a YOLOv8 face-detection model."""

    def __init__(self):
        # Uses the local pt file if available
        self.model = YOLO("yolov8n-face.pt")

    def detect(self, image: np.ndarray):
        """Run face detection and return boxes as (x, y, w, h) int tuples."""
        detections = []
        for result in self.model(image, conf=DETECTION_CONFIDENCE, verbose=False):
            boxes = result.boxes
            if boxes is None:
                continue
            for box in boxes:
                x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])
                # Convert corner format (x1,y1,x2,y2) to (x, y, width, height).
                detections.append((x1, y1, x2 - x1, y2 - y1))
        return detections

GLOBAL_DETECTOR = Detector()

def apply_blur(image, x, y, w, h):
    """Pixelate the (x, y, w, h) region of `image` in place and return it.

    The region is shrunk to a small square grid and scaled back up with
    nearest-neighbour interpolation, producing a mosaic effect. The grid
    is capped so small faces still get coarse (strong) blocks.
    """
    region = image[y:y+h, x:x+w]
    if region.size == 0:
        return image

    # Never let a mosaic block shrink below MIN_PIXEL_SIZE pixels wide.
    width_cap = max(1, w // MIN_PIXEL_SIZE)
    grid = max(2, min(TARGET_MOSAIC_GRID, width_cap))

    shrunk = cv2.resize(region, (grid, grid), interpolation=cv2.INTER_LINEAR)
    image[y:y+h, x:x+w] = cv2.resize(shrunk, (w, h), interpolation=cv2.INTER_NEAREST)
    return image

def draw_smart_label(image, x, y, w, h, text, color):
    """Draw a bounding box with a centered label rendered OUTSIDE the box.

    UX FIX: the label sits above the box by default and flips below it when
    the box is too close to the top edge (Header/Footer style), so the text
    never covers the face itself. An empty `text` draws only the box.
    """
    # 1. Draw Box
    border = 2 if w > 40 else 1
    cv2.rectangle(image, (x, y), (x+w, y+h), color, border)

    if not text:
        return

    # 2. Measure Text
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 0.6
    weight = 2
    (text_w, text_h), _ = cv2.getTextSize(text, font, scale, weight)

    # 3. Smart Positioning: prefer above the box; flip below when clipped.
    baseline_y = y - 10
    if y - text_h - 15 < 0:
        baseline_y = y + h + text_h + 10

    # Center the label horizontally on the box.
    label_x = x + (w // 2) - (text_w // 2)

    # Filled background rectangle behind the text for readability.
    margin = 4
    cv2.rectangle(image,
                  (label_x - margin, baseline_y - text_h - margin),
                  (label_x + text_w + margin, baseline_y + margin),
                  color, -1)

    cv2.putText(image, text, (label_x, baseline_y), font, scale, (255, 255, 255), weight, cv2.LINE_AA)

def process_frame(image, mode):
    """
    MASTER FUNCTION: detect every face in `image` and handle it per `mode`.

    Args:
        image: numpy frame; recognition assumes RGB channel order
               (NOTE(review): video frames from cv2 arrive as BGR — confirm
               the intended channel handling for the video path).
        mode: "privacy" (blur only), "data" (recognize + label, no blur),
              or "smart" (recognize + label + blur).

    Returns:
        (processed_image, log_text); (None, "No Image") when input is None.

    Ordering matters: the face crop for recognition is taken BEFORE blur is
    applied, and boxes/labels are drawn AFTER all blurring so the UI layer
    sits on top of the pixelation.
    """
    if image is None: return None, "No Image"
    
    # Detect
    faces = GLOBAL_DETECTOR.detect(image)
    processed_img = image.copy()
    log_entries = []
    
    # Queue drawing to ensure labels are ON TOP of blur
    draw_queue = [] 

    for i, (raw_x, raw_y, raw_w, raw_h) in enumerate(faces):
        
        # 1. UNIFIED COORDINATES (Fixes "Bleeding" Blur)
        px, py, pw, ph = get_padded_coords(processed_img, raw_x, raw_y, raw_w, raw_h)

        label_text = ""
        box_color = (200, 0, 0) # Default Red
        log_text = "Unknown"

        # 2. ANALYSIS (Data/Smart Mode)
        if mode in ["data", "smart"]:
            # Crop using padded coords for better context
            # (must happen before step 3 blurs this region in place).
            face_crop = processed_img[py:py+ph, px:px+pw]
            
            if face_crop.size > 0:
                res = FACE_DB.recognize(face_crop)
                if res['match']:
                    label_text = f"ID: {res['id']}"
                    box_color = (0, 200, 0) # Green
                    log_text = f"MATCH: {res['name']}"
                else:
                    label_text = "Unknown"
                    log_text = "Unknown Person"
            
            draw_queue.append((px, py, pw, ph, label_text, box_color))
            log_entries.append(f"Face #{i+1}: {log_text}")

        # 3. MODIFICATION (Privacy/Smart Mode)
        if mode in ["privacy", "smart"]:
            processed_img = apply_blur(processed_img, px, py, pw, ph)
            # "smart" mode already logged the recognition result above.
            if mode == "privacy":
                log_entries.append(f"Face #{i+1}: Redacted")

    # 4. DRAW UI (Top Layer)
    for (dx, dy, dw, dh, txt, col) in draw_queue:
        draw_smart_label(processed_img, dx, dy, dw, dh, txt, col)

    final_log = "--- Detection Report ---\n" + "\n".join(log_entries) if log_entries else "No faces detected."
    return processed_img, final_log

# ====================================================
# 4. VIDEO PROCESSING
# ====================================================
def process_video_general(video_path, mode, progress=gr.Progress()):
    """Run process_frame over every frame of a video and write the result.

    Args:
        video_path: Path to the input video; None/"" returns None.
        mode: Processing mode forwarded to process_frame
              ("privacy", "data" or "smart").
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Path to the processed temp video, or None when the input is
        missing or unreadable.
    """
    if not video_path: return None
    
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened(): return None

    fps = cap.get(cv2.CAP_PROP_FPS)
    # FIX: some containers report 0 or NaN FPS, which breaks VideoWriter.
    if not fps or fps != fps or fps <= 0:
        fps = 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    out_path = create_temp_file()
    # Try mp4v for better compatibility
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    
    cnt = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret: break
            
            res_frame, _ = process_frame(frame, mode)
            out.write(res_frame)
            
            cnt += 1
            # Throttle progress updates to every 5th frame to limit UI churn.
            if total > 0 and cnt % 5 == 0: 
                progress(cnt/total, desc=f"Processing Frame {cnt}/{total}")
    finally:
        # FIX: release capture/writer even if processing raises mid-video;
        # previously an exception leaked both handles and left the output
        # file unfinalized.
        cap.release()
        out.release()
    return out_path

# ====================================================
# 5. GRADIO INTERFACE (FULL)
# ====================================================
# UI layout: three top-level tabs, one per processing mode ("privacy",
# "data", "smart"); each tab offers Image / Video / Webcam sub-tabs that
# all funnel into process_frame / process_video_general.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="Smart Redaction Pro") as demo:
    
    gr.Markdown("# πŸ›‘οΈ Smart Redaction System (Enterprise Patch)")
    gr.Markdown(f"**Status:** {FACE_DB.get_stats()} | **Strictness:** {RECOGNITION_THRESHOLD}")
    
    with gr.Tabs():
        
        # --- TAB 1: RAW PRIVACY ---
        with gr.TabItem("1️⃣ Raw Privacy"):
            gr.Markdown("### πŸ”’ Total Anonymization")
            with gr.Tabs():
                with gr.TabItem("Image"):
                    p_img_in = gr.Image(label="Input", type="numpy", height=400)
                    p_img_out = gr.Image(label="Output", height=400)
                    p_btn = gr.Button("Apply Privacy", variant="primary")
                with gr.TabItem("Video"):
                    p_vid_in = gr.Video(label="Input Video")
                    p_vid_out = gr.Video(label="Output Video")
                    p_vid_btn = gr.Button("Process Video", variant="primary")
                with gr.TabItem("Webcam"):
                    p_web_in = gr.Image(sources=["webcam"], streaming=True, type="numpy")
                    p_web_out = gr.Image(label="Live Stream")

        # --- TAB 2: DATA LAYER ---
        with gr.TabItem("2️⃣ Security Data"):
            gr.Markdown("### πŸ” Recognition (No Blur)")
            with gr.Tabs():
                with gr.TabItem("Image"):
                    with gr.Row():
                        d_img_in = gr.Image(label="Input", type="numpy", height=400)
                        with gr.Column():
                            d_img_out = gr.Image(label="Output", height=400)
                            d_log_out = gr.Textbox(label="Logs", lines=4)
                    d_btn = gr.Button("Analyze", variant="primary")
                with gr.TabItem("Video"):
                    d_vid_in = gr.Video(label="Input Video")
                    d_vid_out = gr.Video(label="Output Video")
                    d_vid_btn = gr.Button("Analyze Video", variant="primary")
                with gr.TabItem("Webcam"):
                    d_web_in = gr.Image(sources=["webcam"], streaming=True, type="numpy")
                    d_web_out = gr.Image(label="Live Data Stream")

        # --- TAB 3: SMART REDACTION ---
        with gr.TabItem("3️⃣ Smart Redaction"):
            gr.Markdown("### πŸ›‘οΈ Identity + Privacy")
            with gr.Tabs():
                with gr.TabItem("Image"):
                    with gr.Row():
                        s_img_in = gr.Image(label="Input", type="numpy", height=400)
                        with gr.Column():
                            s_img_out = gr.Image(label="Output", height=400)
                            s_log_out = gr.Textbox(label="Logs", lines=4)
                    s_btn = gr.Button("Apply Smart Redaction", variant="primary")
                with gr.TabItem("Video"):
                    s_vid_in = gr.Video(label="Input Video")
                    s_vid_out = gr.Video(label="Output Video")
                    s_vid_btn = gr.Button("Process Smart Video", variant="primary")
                with gr.TabItem("Webcam"):
                    s_web_in = gr.Image(sources=["webcam"], streaming=True, type="numpy")
                    s_web_out = gr.Image(label="Live Smart Stream")

    # =========================================
    # WIRING
    # =========================================
    # process_frame returns (image, log); the `[0]` lambdas keep only the
    # image where a tab exposes no log textbox (image-only and webcam outputs).
    
    # Privacy
    p_btn.click(lambda img: process_frame(img, "privacy")[0], inputs=[p_img_in], outputs=p_img_out)
    p_vid_btn.click(lambda vid: process_video_general(vid, "privacy"), inputs=[p_vid_in], outputs=p_vid_out)
    p_web_in.stream(lambda img: process_frame(img, "privacy")[0], inputs=[p_web_in], outputs=p_web_out)

    # Data
    d_btn.click(lambda img: process_frame(img, "data"), inputs=[d_img_in], outputs=[d_img_out, d_log_out])
    d_vid_btn.click(lambda vid: process_video_general(vid, "data"), inputs=[d_vid_in], outputs=d_vid_out)
    d_web_in.stream(lambda img: process_frame(img, "data")[0], inputs=[d_web_in], outputs=d_web_out)

    # Smart
    s_btn.click(lambda img: process_frame(img, "smart"), inputs=[s_img_in], outputs=[s_img_out, s_log_out])
    s_vid_btn.click(lambda vid: process_video_general(vid, "smart"), inputs=[s_vid_in], outputs=s_vid_out)
    s_web_in.stream(lambda img: process_frame(img, "smart")[0], inputs=[s_web_in], outputs=s_web_out)

# Launch the app only when executed directly (not when imported).
if __name__ == "__main__":
    demo.launch()