# NOTE: Hugging Face Space page residue (repo name / author / commit banner),
# preserved here as a comment so the file stays valid Python:
#   test_space1 / app.py — ab2207's picture — "Update app.py" — 28f65c2 verified
# --- Standard Libraries ---
import logging
import atexit
import tempfile
import os
import hashlib
import shutil
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional
from pathlib import Path
# ==========================================
# πŸš€ FAST STARTUP: LOCAL WEIGHTS SETUP
# ==========================================
def setup_local_weights():
    """Point DeepFace at the working directory and pre-install bundled weights.

    Sets DEEPFACE_HOME to the current directory and, if a repo-local
    ``facenet512_weights.h5`` is present, copies it into
    ``.deepface/weights`` so DeepFace skips its slow first-run download.
    """
    os.environ["DEEPFACE_HOME"] = os.getcwd()
    weights_dir = os.path.join(os.getcwd(), ".deepface", "weights")
    os.makedirs(weights_dir, exist_ok=True)
    bundled = "facenet512_weights.h5"
    installed = os.path.join(weights_dir, bundled)
    if os.path.exists(installed):
        print("βœ… Weights already installed.")
    elif os.path.exists(bundled):
        print(f"πŸ“¦ Found local weights: {bundled}. Installing...")
        shutil.copy(bundled, installed)  # Use copy so original stays visible in files
    else:
        print("⚠️ Local weights not found. DeepFace might download them.")
# RUN THIS IMMEDIATELY BEFORE IMPORTS
setup_local_weights()
# --- Computer Vision & UI Libraries ---
import cv2
import numpy as np
import gradio as gr
from ultralytics import YOLO
# --- Face Recognition Libraries ---
# Optional dependency: DeepFace provides embeddings; app degrades gracefully without it.
try:
    from deepface import DeepFace
    DEEPFACE_AVAILABLE = True
except ImportError:
    DEEPFACE_AVAILABLE = False
    logging.warning("⚠️ DeepFace not installed - Recognition disabled.")
# Optional dependency: ChromaDB persists/queries the face embeddings.
try:
    import chromadb
    CHROMADB_AVAILABLE = True
except ImportError:
    CHROMADB_AVAILABLE = False
    logging.warning("⚠️ ChromaDB not installed - Database features disabled.")
# --- Configure Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# ====================================================
# 1. CONFIGURATION & UTILITIES
# ====================================================
# --- TUNED PARAMETERS ---
DETECTION_CONFIDENCE = 0.4    # Moderate strictness to reduce false boxes
RECOGNITION_THRESHOLD = 0.30  # STRICT: max cosine distance accepted as a match (prevents identity confusion)
TARGET_MOSAIC_GRID = 10       # Pixelation Grid Size
MIN_PIXEL_SIZE = 12           # Min block size (prevents weak blur on small faces)
COVERAGE_SCALE = 1.2          # 120% padding (Captures hair/chin)
# Registry of every temp file handed out below; purged at interpreter exit.
TEMP_FILES = []

def cleanup_temp_files():
    """Best-effort removal of all registered temp files at process exit."""
    for f in TEMP_FILES:
        try:
            if os.path.exists(f):
                os.remove(f)
        except OSError:
            # Best-effort: a locked/busy file must not abort interpreter exit.
            pass

atexit.register(cleanup_temp_files)

def create_temp_file(suffix=".mp4") -> str:
    """Create a unique temp file, register it for exit cleanup, return its path.

    FIX: the original used tempfile.mktemp(), which is deprecated and
    race-prone (another process can claim the name before it is used).
    NamedTemporaryFile(delete=False) atomically creates the file instead.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        path = tmp.name
    TEMP_FILES.append(path)
    return path
def get_padded_coords(image, x, y, w, h, scale=COVERAGE_SCALE):
    """
    UNIFIED COORDINATE SYSTEM:
    Expands a raw (x, y, w, h) box by `scale` around its centre and clamps it
    to the image bounds, so the blur region and the drawn box always agree.
    """
    img_h, img_w = image.shape[:2]
    extra_w = int(w * (scale - 1.0) / 2)
    extra_h = int(h * (scale - 1.0) / 2)
    left = max(0, x - extra_w)
    top = max(0, y - extra_h)
    width = min(img_w - left, w + 2 * extra_w)
    height = min(img_h - top, h + 2 * extra_h)
    return left, top, width, height
# ====================================================
# 2. THE DATABASE LAYER
# ====================================================
class FaceDatabase:
    """Persistent face-embedding store backed by ChromaDB + DeepFace (Facenet512).

    Indexes images found under ``known_faces/<id>_<name>/`` and answers
    nearest-neighbour identity queries against them (cosine distance).
    Degrades to an inactive stub when DeepFace or ChromaDB is missing.
    """

    def __init__(self, db_path: str = "./chroma_db", faces_dir: str = "known_faces"):
        self.faces_dir = Path(faces_dir)
        self.collection = None
        self.is_active = False
        # Without both optional dependencies the DB stays a no-op stub.
        if not DEEPFACE_AVAILABLE or not CHROMADB_AVAILABLE:
            return
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            # Cosine space — distances compared against RECOGNITION_THRESHOLD below.
            self.collection = self.client.get_or_create_collection(
                name="face_embeddings", metadata={"hnsw:space": "cosine"}
            )
            self.is_active = True
            if self.faces_dir.exists():
                self._scan_and_index()
            else:
                self.faces_dir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            logger.error(f"❌ DB Init Error: {e}")

    def _get_hash(self, img_path: Path) -> str:
        """MD5 of the file contents — a stable, dedup-friendly record id."""
        with open(img_path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()

    def _scan_and_index(self):
        """Walk ``known_faces/`` and embed any image not already indexed.

        Folder naming convention: ``<id>_<display name>`` (underscores in the
        name part become spaces); a folder without an id prefix gets id "000".
        """
        logger.info("πŸ”„ Scanning 'known_faces' folder...")
        for person_dir in self.faces_dir.iterdir():
            if not person_dir.is_dir():
                continue
            parts = person_dir.name.split('_', 1)
            p_id = parts[0] if len(parts) > 1 else "000"
            p_name = parts[1].replace('_', ' ') if len(parts) > 1 else person_dir.name
            for img_path in person_dir.glob("*.*"):
                if img_path.suffix.lower() not in ['.jpg', '.png', '.webp', '.jpeg']:
                    continue
                try:
                    img_hash = self._get_hash(img_path)
                    # Skip if already indexed (file hash doubles as record id).
                    if self.collection.get(ids=[img_hash])['ids']:
                        continue
                    embedding_objs = DeepFace.represent(
                        img_path=str(img_path),
                        model_name="Facenet512",
                        enforce_detection=False
                    )
                    if embedding_objs:
                        self.collection.add(
                            ids=[img_hash],
                            embeddings=[embedding_objs[0]["embedding"]],
                            metadatas=[{"id": p_id, "name": p_name, "file": img_path.name}]
                        )
                        logger.info(f"βœ… Indexed: {p_name}")
                except Exception as e:
                    logger.error(f"⚠️ Skip {img_path.name}: {e}")

    def recognize(self, face_img: np.ndarray) -> Dict[str, Any]:
        """Match an RGB face crop against the indexed identities.

        Returns a dict with ``match``, ``name``, ``id`` and ``color`` keys;
        unmatched faces get the red default.
        """
        default = {"match": False, "name": "Unknown", "id": "Unknown", "color": (255, 0, 0)}  # Red
        if not self.is_active or self.collection.count() == 0:
            return default
        tmp_path = None
        try:
            # FIX: unique temp file instead of a shared hard-coded
            # "temp_query.jpg" — concurrent Gradio requests raced on it.
            fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)
            cv2.imwrite(tmp_path, cv2.cvtColor(face_img, cv2.COLOR_RGB2BGR))
            embedding_objs = DeepFace.represent(img_path=tmp_path, model_name="Facenet512", enforce_detection=False)
            if not embedding_objs:
                return default
            results = self.collection.query(query_embeddings=[embedding_objs[0]["embedding"]], n_results=1)
            if not results['ids'][0]:
                return default
            distance = results['distances'][0][0]
            metadata = results['metadatas'][0][0]
            # STRICT THRESHOLD APPLIED HERE
            if distance < RECOGNITION_THRESHOLD:
                return {
                    "match": True,
                    "name": metadata['name'],
                    "id": metadata['id'],
                    "color": (0, 255, 0)  # Green
                }
            return default
        except Exception as e:
            # FIX: was a silent swallow with an unused variable — at least
            # record why recognition failed so field issues are diagnosable.
            logger.debug(f"Recognition error: {e}")
            return default
        finally:
            # FIX: the temp query image is now removed even on exceptions.
            if tmp_path and os.path.exists(tmp_path):
                os.remove(tmp_path)

    def get_stats(self) -> str:
        """Human-readable status line for the UI header."""
        return f"βœ… Active | {self.collection.count()} Faces" if (self.is_active and self.collection) else "❌ Offline"
# Module-level singleton shared by every request handler; builds/loads the
# index once at import time.
FACE_DB = FaceDatabase()
# ====================================================
# 3. DETECTOR & PROCESSING LOGIC
# ====================================================
class Detector:
    """Thin wrapper around the YOLOv8 face model producing (x, y, w, h) boxes."""

    def __init__(self):
        # Ultralytics resolves this to the local .pt file when one is present.
        self.model = YOLO("yolov8n-face.pt")

    def detect(self, image: np.ndarray):
        """Run detection on `image` and return a list of (x, y, w, h) int tuples."""
        predictions = self.model(image, conf=DETECTION_CONFIDENCE, verbose=False)
        boxes = []
        for pred in predictions:
            if pred.boxes is None:
                continue
            for det in pred.boxes:
                x1, y1, x2, y2 = map(int, det.xyxy[0])
                boxes.append((x1, y1, x2 - x1, y2 - y1))
        return boxes
# Single shared detector; the YOLO model is loaded once at import time.
GLOBAL_DETECTOR = Detector()
def apply_blur(image, x, y, w, h):
    """Pixelate the (x, y, w, h) region of `image` in place and return it.

    The mosaic grid is capped so each block spans at least MIN_PIXEL_SIZE
    pixels, preventing a weak (nearly invisible) blur on small faces.
    """
    roi = image[y:y+h, x:x+w]
    if roi.size == 0:
        return image
    # FIX: bound the grid by the SMALLER side — the original used only `w`,
    # so short-wide boxes could get blocks finer than MIN_PIXEL_SIZE vertically.
    grid_pixel_limit = max(1, min(w, h) // MIN_PIXEL_SIZE)
    final_grid_size = max(2, min(TARGET_MOSAIC_GRID, grid_pixel_limit))
    # Downscale to the coarse grid, then upscale with nearest-neighbour to
    # produce hard-edged mosaic blocks.
    small = cv2.resize(roi, (final_grid_size, final_grid_size), interpolation=cv2.INTER_LINEAR)
    pixelated = cv2.resize(small, (w, h), interpolation=cv2.INTER_NEAREST)
    image[y:y+h, x:x+w] = pixelated
    return image
def draw_smart_label(image, x, y, w, h, text, color):
    """
    UX FIX: Draws label OUTSIDE the box (Header/Footer style)
    """
    # Bounding box first; thinner line for tiny faces.
    border = 2 if w > 40 else 1
    cv2.rectangle(image, (x, y), (x + w, y + h), color, border)
    if not text:
        return
    # Measure the label so it can be centred and given a background.
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 0.6
    weight = 2
    (label_w, label_h), _ = cv2.getTextSize(text, font, scale, weight)
    # Prefer above the box; flip below it when too close to the top edge.
    baseline_y = y + h + label_h + 10 if y - label_h - 15 < 0 else y - 10
    # Centre the text horizontally on the face box.
    left = x + (w // 2) - (label_w // 2)
    # Filled background rectangle behind the text for contrast.
    pad = 4
    cv2.rectangle(image,
                  (left - pad, baseline_y - label_h - pad),
                  (left + label_w + pad, baseline_y + pad),
                  color, -1)
    cv2.putText(image, text, (left, baseline_y), font, scale, (255, 255, 255), weight, cv2.LINE_AA)
def process_frame(image, mode):
    """Detect faces in `image` and apply the per-mode pipeline.

    Modes:
      - "privacy": pixelate every face (no labels or boxes).
      - "data":    recognize + draw labelled boxes (no blur).
      - "smart":   recognize, blur, then draw labels on top of the blur.

    Returns (processed_image, log_text).
    """
    if image is None:
        return None, "No Image"
    # Detect
    faces = GLOBAL_DETECTOR.detect(image)
    processed_img = image.copy()
    log_entries = []
    # Queue drawing to ensure labels are ON TOP of blur
    draw_queue = []
    for i, (raw_x, raw_y, raw_w, raw_h) in enumerate(faces):
        # 1. UNIFIED COORDINATES (Fixes "Bleeding" Blur)
        px, py, pw, ph = get_padded_coords(processed_img, raw_x, raw_y, raw_w, raw_h)
        label_text = ""
        box_color = (200, 0, 0)  # Default Red
        log_text = "Unknown"
        # 2. ANALYSIS (Data/Smart Mode)
        if mode in ["data", "smart"]:
            # FIX: crop from the ORIGINAL image, not processed_img — in smart
            # mode earlier faces are already pixelated, so overlapping crops
            # previously fed blurred pixels into the recognizer.
            face_crop = image[py:py+ph, px:px+pw]
            if face_crop.size > 0:
                res = FACE_DB.recognize(face_crop)
                if res['match']:
                    label_text = f"ID: {res['id']}"
                    box_color = (0, 200, 0)  # Green
                    log_text = f"MATCH: {res['name']}"
                else:
                    label_text = "Unknown"
                    log_text = "Unknown Person"
            draw_queue.append((px, py, pw, ph, label_text, box_color))
            log_entries.append(f"Face #{i+1}: {log_text}")
        # 3. MODIFICATION (Privacy/Smart Mode)
        if mode in ["privacy", "smart"]:
            processed_img = apply_blur(processed_img, px, py, pw, ph)
            if mode == "privacy":
                log_entries.append(f"Face #{i+1}: Redacted")
    # 4. DRAW UI (Top Layer)
    for (dx, dy, dw, dh, txt, col) in draw_queue:
        draw_smart_label(processed_img, dx, dy, dw, dh, txt, col)
    final_log = "--- Detection Report ---\n" + "\n".join(log_entries) if log_entries else "No faces detected."
    return processed_img, final_log
# ====================================================
# 4. VIDEO PROCESSING
# ====================================================
def process_video_general(video_path, mode, progress=gr.Progress()):
    """Run process_frame() over every frame of `video_path` in the given mode.

    Writes an mp4 to a registered temp file and returns its path, or None
    when the input is missing or unreadable.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # FIX: some containers report 0 fps, which yields a broken writer.
        if not fps or fps <= 0:
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        out_path = create_temp_file()
        # Try mp4v for better compatibility
        out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        cnt = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            res_frame, _ = process_frame(frame, mode)
            out.write(res_frame)
            cnt += 1
            # Throttle progress updates to every 5th frame.
            if total > 0 and cnt % 5 == 0:
                progress(cnt / total, desc=f"Processing Frame {cnt}/{total}")
        return out_path
    finally:
        # FIX: release capture and writer even if a frame raises mid-stream;
        # the original leaked both on any exception.
        cap.release()
        if out is not None:
            out.release()
# ====================================================
# 5. GRADIO INTERFACE (FULL)
# ====================================================
# Three top-level tabs share the same process_frame()/process_video_general()
# backends, differing only in the `mode` string they pass.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue"), title="Smart Redaction Pro") as demo:
    gr.Markdown("# πŸ›‘οΈ Smart Redaction System (Enterprise Patch)")
    gr.Markdown(f"**Status:** {FACE_DB.get_stats()} | **Strictness:** {RECOGNITION_THRESHOLD}")
    with gr.Tabs():
        # --- TAB 1: RAW PRIVACY --- (mode="privacy": blur every face)
        with gr.TabItem("1️⃣ Raw Privacy"):
            gr.Markdown("### πŸ”’ Total Anonymization")
            with gr.Tabs():
                with gr.TabItem("Image"):
                    p_img_in = gr.Image(label="Input", type="numpy", height=400)
                    p_img_out = gr.Image(label="Output", height=400)
                    p_btn = gr.Button("Apply Privacy", variant="primary")
                with gr.TabItem("Video"):
                    p_vid_in = gr.Video(label="Input Video")
                    p_vid_out = gr.Video(label="Output Video")
                    p_vid_btn = gr.Button("Process Video", variant="primary")
                with gr.TabItem("Webcam"):
                    p_web_in = gr.Image(sources=["webcam"], streaming=True, type="numpy")
                    p_web_out = gr.Image(label="Live Stream")
        # --- TAB 2: DATA LAYER --- (mode="data": recognize, no blur)
        with gr.TabItem("2️⃣ Security Data"):
            gr.Markdown("### πŸ” Recognition (No Blur)")
            with gr.Tabs():
                with gr.TabItem("Image"):
                    with gr.Row():
                        d_img_in = gr.Image(label="Input", type="numpy", height=400)
                        with gr.Column():
                            d_img_out = gr.Image(label="Output", height=400)
                            d_log_out = gr.Textbox(label="Logs", lines=4)
                    d_btn = gr.Button("Analyze", variant="primary")
                with gr.TabItem("Video"):
                    d_vid_in = gr.Video(label="Input Video")
                    d_vid_out = gr.Video(label="Output Video")
                    d_vid_btn = gr.Button("Analyze Video", variant="primary")
                with gr.TabItem("Webcam"):
                    d_web_in = gr.Image(sources=["webcam"], streaming=True, type="numpy")
                    d_web_out = gr.Image(label="Live Data Stream")
        # --- TAB 3: SMART REDACTION --- (mode="smart": recognize + blur + label)
        with gr.TabItem("3️⃣ Smart Redaction"):
            gr.Markdown("### πŸ›‘οΈ Identity + Privacy")
            with gr.Tabs():
                with gr.TabItem("Image"):
                    with gr.Row():
                        s_img_in = gr.Image(label="Input", type="numpy", height=400)
                        with gr.Column():
                            s_img_out = gr.Image(label="Output", height=400)
                            s_log_out = gr.Textbox(label="Logs", lines=4)
                    s_btn = gr.Button("Apply Smart Redaction", variant="primary")
                with gr.TabItem("Video"):
                    s_vid_in = gr.Video(label="Input Video")
                    s_vid_out = gr.Video(label="Output Video")
                    s_vid_btn = gr.Button("Process Smart Video", variant="primary")
                with gr.TabItem("Webcam"):
                    s_web_in = gr.Image(sources=["webcam"], streaming=True, type="numpy")
                    s_web_out = gr.Image(label="Live Smart Stream")
    # =========================================
    # WIRING
    # =========================================
    # Image handlers that show logs keep both outputs; the rest drop the log
    # with [0]. Webcam inputs use .stream() for continuous frames.
    # Privacy
    p_btn.click(lambda img: process_frame(img, "privacy")[0], inputs=[p_img_in], outputs=p_img_out)
    p_vid_btn.click(lambda vid: process_video_general(vid, "privacy"), inputs=[p_vid_in], outputs=p_vid_out)
    p_web_in.stream(lambda img: process_frame(img, "privacy")[0], inputs=[p_web_in], outputs=p_web_out)
    # Data
    d_btn.click(lambda img: process_frame(img, "data"), inputs=[d_img_in], outputs=[d_img_out, d_log_out])
    d_vid_btn.click(lambda vid: process_video_general(vid, "data"), inputs=[d_vid_in], outputs=d_vid_out)
    d_web_in.stream(lambda img: process_frame(img, "data")[0], inputs=[d_web_in], outputs=d_web_out)
    # Smart
    s_btn.click(lambda img: process_frame(img, "smart"), inputs=[s_img_in], outputs=[s_img_out, s_log_out])
    s_vid_btn.click(lambda vid: process_video_general(vid, "smart"), inputs=[s_vid_in], outputs=s_vid_out)
    s_web_in.stream(lambda img: process_frame(img, "smart")[0], inputs=[s_web_in], outputs=s_web_out)
if __name__ == "__main__":
    demo.launch()