# aefrss / app.py
# Last commit: 2154dc4 "Update app.py" by midokhaled927 (verified)
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import HTMLResponse, JSONResponse
import numpy as np
import cv2
import os
import json
import secrets
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple
import requests
# FastAPI application and the HTTP bearer-token extractor that the protected
# endpoints receive their credentials from.
app = FastAPI(title="AEFRS Face Recognition System", docs_url="/docs")
security = HTTPBearer()

# ========== Directory setup ==========
# Create the on-disk layout up front so later writes (model downloads,
# artifacts/data.json) do not fail on a missing directory.
for _subdir in ("vector_index", "metadata", "models"):
    os.makedirs(f"artifacts/{_subdir}", exist_ok=True)
del _subdir

# ========== Model handles ==========
# Filled in by load_onnx_models(); they stay None when the corresponding
# ONNX model could not be loaded.
detection_session = None
embedding_session = None
def download_model(url: str, save_path: str, timeout: float = 30) -> bool:
    """Download a model file from a direct URL to *save_path*.

    The body is streamed into a temporary ``<save_path>.part`` file, which is
    renamed onto *save_path* only after the whole payload has been written.
    Callers treat the mere existence of *save_path* as "model available", so
    an interrupted download must never leave a truncated file there.

    Args:
        url: Direct download URL.
        save_path: Destination file path.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        True on success. False on a non-200 response or any network/I/O
        error. Errors are printed, never raised.
    """
    tmp_path = save_path + ".part"
    try:
        print(f"📥 Downloading model from {url}...")
        # The context manager guarantees the streamed connection is released.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"❌ Download failed: HTTP {response.status_code}")
                return False
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: save_path is either absent or complete.
        os.replace(tmp_path, save_path)
        print(f"✅ Model saved to {save_path}")
        return True
    except Exception as e:
        print(f"❌ Download error: {e}")
        # Best-effort removal of a partial download; it may never have been created.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False
def _try_load_session(model_path: str, label: str):
    """Create an onnxruntime InferenceSession for *model_path*, or return None.

    Any failure (onnxruntime not installed, unreadable or corrupt model file)
    is printed with *label* and swallowed, so the app can still start in
    fallback mode.
    """
    try:
        import onnxruntime as ort
        return ort.InferenceSession(model_path)
    except Exception as e:
        print(f"⚠️ Failed to load {label}: {e}")
        return None


def load_onnx_models():
    """Load the optional ONNX models into the module-level session globals.

    - ``detection_session``: RetinaFace, from artifacts/models/retinaface.onnx.
      It is loaded only if that file already exists and is never downloaded.
      NOTE(review): nothing in this file reads detection_session yet.
    - ``embedding_session``: the embedding model at
      artifacts/models/arcface_iresnet100.onnx. If the file is missing, a
      MobileFaceNet (128-d) model is downloaded to that path first.
      NOTE(review): the file name says iresnet100 but the downloaded weights
      are MobileFaceNet; consider renaming.

    A global stays None when its model is unavailable. extract_face_embedding
    then uses its non-ONNX fallback feature extractor.
    """
    global detection_session, embedding_session

    # Try to load RetinaFace (lightweight detection model).
    retinaface_path = "artifacts/models/retinaface.onnx"
    if os.path.exists(retinaface_path):
        session = _try_load_session(retinaface_path, "RetinaFace")
        if session is not None:
            detection_session = session
            print("✅ RetinaFace model loaded from file")
    else:
        print(f"⚠️ RetinaFace not found at {retinaface_path}")

    # Try to load the embedding (ArcFace-style) model, downloading it if absent.
    arcface_path = "artifacts/models/arcface_iresnet100.onnx"
    if not os.path.exists(arcface_path):
        print("📥 ArcFace model not found, attempting to download...")
        # MobileFaceNet model link (lightweight, ~5MB).
        url = "https://github.com/leondgarse/Keras_insightface/releases/download/v1.0.0/mobilefacenet_128.onnx"
        if download_model(url, arcface_path):
            print("✅ ArcFace model downloaded successfully")

    if os.path.exists(arcface_path):
        session = _try_load_session(arcface_path, "ArcFace")
        if session is not None:
            embedding_session = session
            print("✅ ArcFace model loaded successfully")
    else:
        print("⚠️ ArcFace model not available, using fallback mode")
# Load the ONNX models once, at import time.
load_onnx_models()

# ========== Data storage ==========
# In-memory state shared by the endpoints. identities_db and vector_index are
# written to DATA_FILE by save_data() and restored by load_data();
# active_tokens is kept in memory only.
active_tokens: Dict[str, dict] = {}  # bearer token -> {"username": ..., "expires": datetime}
identities_db: Dict[str, dict] = {}  # identity_id -> {"name", "embedding", "created_at", "created_by"}
vector_index: List[dict] = []        # [{"identity_id", "name", "embedding"}, ...] scanned by /v1/search
DATA_FILE = "artifacts/data.json"
def load_data():
    """Restore identities_db and vector_index from DATA_FILE.

    Does nothing when DATA_FILE does not exist. A missing key falls back to an
    empty dict/list. Any read or parse error is printed and the in-memory
    state is left as it was.
    """
    global identities_db, vector_index
    if not os.path.exists(DATA_FILE):
        return
    try:
        with open(DATA_FILE) as fh:
            stored = json.load(fh)
        identities_db = stored.get('identities', {})
        vector_index = stored.get('vector_index', [])
        print(f"✅ Loaded {len(identities_db)} identities")
    except Exception as e:
        print(f"Error loading data: {e}")
def save_data():
    """Persist identities_db and vector_index to DATA_FILE as JSON.

    The JSON is first written to ``DATA_FILE + ".tmp"`` and then moved over
    DATA_FILE with os.replace. Opening DATA_FILE directly in 'w' mode would
    truncate it, so a failure half-way through json.dump (an unserializable
    value, a crash, a full disk) would destroy the only copy of the database.
    Errors are printed, not raised, and the previous file is left intact.
    """
    tmp_path = DATA_FILE + ".tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump({
                'identities': identities_db,
                'vector_index': vector_index
            }, f, indent=2)
        os.replace(tmp_path, DATA_FILE)
        print(f"✅ Saved {len(identities_db)} identities")
    except Exception as e:
        print(f"Error saving data: {e}")
        # Drop the partial temp file; it may not exist if open() itself failed.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
# Restore persisted identities at import time (load_data is a no-op when the data file does not exist).
load_data()
# ========== Token functions ==========
def generate_token(username: str) -> str:
    """Issue a random bearer token for *username*, valid for 24 hours.

    The token is stored in the module-level ``active_tokens`` map together
    with its expiry time. verify_token only removes an expired token when that
    token is presented again, so tokens that are simply abandoned would
    accumulate forever. To bound memory, every token that has already expired
    is purged here whenever a new one is issued.
    """
    now = datetime.now()
    # Same expiry rule as verify_token (expired when now > expires).
    expired = [tok for tok, data in active_tokens.items() if now > data["expires"]]
    for tok in expired:
        del active_tokens[tok]

    token = secrets.token_urlsafe(32)
    active_tokens[token] = {
        "username": username,
        "expires": now + timedelta(hours=24)
    }
    return token
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Resolve a bearer token to its username, or raise HTTP 401.

    The endpoints in this file call this directly with the credentials they
    received. An unknown token raises "Invalid token". An expired token is
    removed from active_tokens and raises "Token expired".
    """
    presented = credentials.credentials
    entry = active_tokens.get(presented)
    if entry is None:
        raise HTTPException(status_code=401, detail="Invalid token")
    if datetime.now() > entry["expires"]:
        active_tokens.pop(presented, None)
        raise HTTPException(status_code=401, detail="Token expired")
    return entry["username"]
# ========== Feature extraction ==========
def extract_features_fallback(image_bytes) -> Optional[np.ndarray]:
    """Compute a 512-value hand-crafted feature vector for an encoded image.

    Used when no ONNX embedding model is loaded or model inference fails.

    Pipeline:
    1. Decode the whole image, convert it to grayscale and resize it to
       128x128.
    2. Compute HOG features with scikit-image when available; otherwise use
       raw pixel intensities scaled to [0, 1].
    3. Stride-subsample or zero-pad the result to exactly 512 values.
    4. Standardize to zero mean and (approximately) unit variance.

    Returns None when *image_bytes* is empty or cannot be decoded as an image.
    """
    # cv2.imdecode raises (instead of returning None) on an empty buffer; reject
    # empty uploads here so callers can answer with a 400 instead of a 500.
    if not image_bytes:
        return None
    img_array = np.frombuffer(image_bytes, np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if img is None:
        return None
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    resized = cv2.resize(gray, (128, 128))
    # HOG with settings chosen to better separate different people.
    try:
        from skimage.feature import hog
        features = hog(resized, orientations=12, pixels_per_cell=(8, 8),
                       cells_per_block=(2, 2), visualize=False)
    except Exception:
        # skimage missing (ImportError) or HOG failed: degrade to raw pixels.
        # Not a bare `except:`, so KeyboardInterrupt/SystemExit still propagate.
        features = resized.flatten() / 255.0
    # Reduce to 512 values: stride-subsample long vectors, zero-pad short ones.
    step = len(features) // 512
    if step > 0:
        features = features[::step][:512]
    else:
        features = np.pad(features, (0, 512 - len(features)))
    # Standardize; the epsilon avoids division by zero for constant images.
    features = (features - features.mean()) / (features.std() + 1e-6)
    return features.astype(np.float32)
def _embedding_input_spec(session) -> Tuple[str, int, int, bool]:
    """Return ``(input_name, height, width, channels_last)`` for an ONNX embedding model.

    Face-recognition ONNX exports use either NCHW ``(N, 3, H, W)`` or NHWC
    ``(N, H, W, 3)`` inputs. Dynamic dimensions are reported as strings or
    None; in that case the conventional 112x112 input size is assumed.
    """
    model_input = session.get_inputs()[0]
    shape = list(model_input.shape or [])
    channels_last = len(shape) == 4 and shape[3] == 3 and shape[1] != 3
    dims = shape[1:3] if channels_last else shape[2:4]
    if len(dims) == 2 and all(isinstance(d, int) and d > 0 for d in dims):
        height, width = dims
    else:
        height, width = 112, 112
    return model_input.name, height, width, channels_last


def extract_face_embedding(image_bytes) -> Optional[np.ndarray]:
    """Return a float32 embedding for an encoded image, or None.

    When ``embedding_session`` is loaded, the image is prepared as follows:
    - The whole decoded image is resized to the model's declared input size
      (112x112 when it is dynamic).
    - Pixels are scaled by (x - 127.5) / 128.
    - The image is laid out as the model expects (NCHW or NHWC).

    The first batch item of the first model output is returned.

    NOTE(review): no face detection/cropping is performed (detection_session
    is unused), and pixels are passed in OpenCV's BGR order. Confirm both
    match the model's training preprocessing.

    Falls back to extract_features_fallback() when no model is loaded or
    inference raises. Fallback vectors are not comparable with model
    embeddings and may differ in length.

    Returns None for empty or undecodable input.
    """
    if not image_bytes:
        return None
    session = embedding_session
    if session is not None:
        try:
            img_array = np.frombuffer(image_bytes, np.uint8)
            img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
            if img is None:
                # Not an image: the fallback would fail to decode it as well.
                return None
            input_name, height, width, channels_last = _embedding_input_spec(session)
            # Prepare the image for the model (cv2.resize takes (width, height)).
            face = cv2.resize(img, (width, height)).astype(np.float32)
            face = (face - 127.5) / 128.0
            if not channels_last:
                face = np.transpose(face, (2, 0, 1))  # HWC -> CHW
            face = np.expand_dims(face, axis=0)
            embedding = session.run(None, {input_name: face})[0][0]
            return embedding.astype(np.float32)
        except Exception as e:
            print(f"ArcFace error: {e}")
    # No model, or model inference failed: use the fallback extractor.
    return extract_features_fallback(image_bytes)
def cosine_similarity(a, b) -> float:
    """Return the cosine similarity of *a* and *b*, rescaled from [-1, 1] to [0, 1].

    Returns 0.0 (no similarity) in three cases:
    - either vector has zero norm;
    - the vectors have different lengths (e.g. one was produced by the
      fallback extractor and the other by the ONNX model), since they cannot
      be compared;
    - the result is not finite because an input contains NaN/inf.

    Previously a NaN slipped through ``min(1.0, nan)`` as 1.0, i.e. a perfect
    match.
    """
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    if a.shape != b.shape:
        return 0.0
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    similarity = float(np.dot(a, b) / (norm_a * norm_b))
    if not np.isfinite(similarity):
        return 0.0
    similarity = (similarity + 1) / 2
    return max(0.0, min(1.0, similarity))
# ========== API Endpoints ==========
@app.get("/healthz")
async def health():
    """Unauthenticated status probe, polled by the web UI's status bar.

    Reports the number of enrolled identities, whether the ONNX embedding
    model is loaded, and the resulting run mode ("production" with the model,
    "fallback" without it).
    """
    arcface_loaded = embedding_session is not None
    run_mode = "production" if embedding_session else "fallback"
    return {
        "status": "ok",
        "system": "AEFRS",
        "timestamp": datetime.now().isoformat(),
        "identities_count": len(identities_db),
        "models": {"arcface": arcface_loaded},
        "mode": run_mode,
    }
@app.post("/v1/token")
async def login(username: str = Form(...), password: str = Form(...)):
    """Issue a 24-hour bearer token for *username*.

    SECURITY NOTE(review): any non-empty username/password pair is accepted,
    and the password is never checked. Anyone who can reach the service can
    therefore obtain a token and enroll, search or delete biometric
    identities. Wire this to a real credential check before exposing it.
    """
    if not (username and password):
        raise HTTPException(status_code=401, detail="Invalid credentials")
    access_token = generate_token(username)
    return {
        "access_token": access_token,
        "token_type": "bearer",
        "expires_in": 86400,
        "username": username,
    }
@app.post("/v1/enroll")
async def enroll(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    identity_id: str = Form(...),
    name: str = Form(...),
    image: UploadFile = File(...)
):
    """Enroll (or re-enroll) *identity_id* with the face in *image*.

    Requires a valid bearer token. The embedding is stored in identities_db
    (with name, timestamp and enrolling user) and in vector_index, and both
    are persisted with save_data().

    Enrolling an existing identity_id replaces its previous vector_index
    entry. Before, a second entry was appended, so searches kept returning the
    stale name/embedding as a duplicate match.

    Returns 400 if no features can be extracted from the upload. Other
    failures are reported as HTTP 500 with the exception text.
    """
    from fastapi.concurrency import run_in_threadpool

    username = verify_token(credentials)
    try:
        contents = await image.read()
        # Decoding + model inference is CPU-bound; run it in the threadpool so
        # the event loop keeps serving other requests meanwhile.
        embedding = await run_in_threadpool(extract_face_embedding, contents)
        if embedding is None:
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "Could not extract face features"}
            )
        embedding_list = embedding.tolist()
        identities_db[identity_id] = {
            "name": name,
            "embedding": embedding_list,
            "created_at": datetime.now().isoformat(),
            "created_by": username
        }
        # Drop any previous entry for this id, in place so the module-level
        # list object is kept, then add the fresh one.
        vector_index[:] = [item for item in vector_index if item["identity_id"] != identity_id]
        vector_index.append({
            "identity_id": identity_id,
            "name": name,
            "embedding": embedding_list
        })
        save_data()
        return {
            "status": "success",
            "message": f"تم تسجيل {name} بنجاح",
            "identity_id": identity_id,
            "name": name,
            "embedding_dim": len(embedding),
            "mode": "production" if embedding_session else "fallback"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/search")
async def search(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    image: UploadFile = File(...),
    top_k: int = Form(5)
):
    """Rank enrolled identities by similarity to the face in *image*.

    Requires a valid bearer token. The query embedding is scored against every
    vector_index entry with cosine_similarity, and the best *top_k* matches
    are returned, highest similarity first. A non-positive *top_k* yields no
    matches.

    Entries whose stored embedding has a different length than the query are
    skipped: they came from a different extractor (fallback vs ONNX model) and
    cannot be compared. Before, one such entry made the whole search fail
    with a 500.

    Returns 400 if no features can be extracted from the upload. Other
    failures are reported as HTTP 500 with the exception text.
    """
    import heapq
    from fastapi.concurrency import run_in_threadpool

    username = verify_token(credentials)
    try:
        contents = await image.read()
        # Decoding + model inference is CPU-bound; run it in the threadpool so
        # the event loop keeps serving other requests meanwhile.
        query_embedding = await run_in_threadpool(extract_face_embedding, contents)
        if query_embedding is None:
            return JSONResponse(
                status_code=400,
                content={"status": "error", "message": "Could not extract face features"}
            )
        query_dim = len(query_embedding)
        candidates = []
        for item in vector_index:
            if len(item["embedding"]) != query_dim:
                continue
            similarity = cosine_similarity(query_embedding, item["embedding"])
            candidates.append({
                "identity_id": item["identity_id"],
                "name": item["name"],
                "similarity": similarity,
                "similarity_percent": round(similarity * 100, 2)
            })
        # nlargest(n, ...) equals sorted(..., reverse=True)[:n] (same tie order),
        # but a negative top_k gives [] instead of a slice dropping the tail.
        results = heapq.nlargest(top_k, candidates, key=lambda x: x["similarity"])
        return {
            "status": "success",
            "message": "Search completed",
            "matches": results,
            "total_matches": len(results),
            "search_by": username,
            "mode": "production" if embedding_session else "fallback"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/identities")
async def list_identities(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    limit: int = 100,
    offset: int = 0
):
    """Return one page of enrolled identities (id, name, creation time).

    Requires a valid bearer token. ``total`` is the number of enrolled
    identities. ``identities`` holds at most *limit* of them starting at
    *offset*, in identities_db iteration order. Negative *limit* or *offset*
    values are clamped to 0; otherwise they would trigger Python's
    count-from-the-end slicing and return an unrelated page.
    """
    verify_token(credentials)
    offset = max(offset, 0)
    limit = max(limit, 0)
    identities_list = [
        {
            "identity_id": k,
            "name": v["name"],
            "created_at": v["created_at"]
        }
        for k, v in identities_db.items()
    ]
    return {
        "total": len(identities_list),
        "identities": identities_list[offset:offset + limit]
    }
@app.delete("/v1/identity/{identity_id}")
async def delete_identity(
    identity_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Remove *identity_id* from identities_db and vector_index, then persist.

    Requires a valid bearer token. Responds with 404 if the id is not enrolled.
    """
    global vector_index
    verify_token(credentials)
    if identity_id not in identities_db:
        raise HTTPException(status_code=404, detail="Identity not found")
    identities_db.pop(identity_id)
    remaining = [
        entry for entry in vector_index
        if entry["identity_id"] != identity_id
    ]
    vector_index = remaining
    save_data()
    return {"message": f"Deleted {identity_id}"}
@app.post("/v1/logout")
async def logout(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Revoke the presented bearer token.

    Unlike the other endpoints this does not call verify_token, so a token
    that has expired but is still in active_tokens can be logged out. Tokens
    that are not in active_tokens get HTTP 400.
    """
    revoked = active_tokens.pop(credentials.credentials, None)
    if revoked is None:
        raise HTTPException(status_code=400, detail="Invalid token")
    return {"message": "Logged out"}
# ========== User interface ==========
# Single-page HTML/JS client served at "/" (same page as before, plus a
# display of the current run mode).
# Identity names and IDs are chosen by whoever enrolls them. They are
# HTML-escaped (escapeHtml) before being inserted via innerHTML, and
# showResult() renders messages as plain text (textContent). Otherwise a
# crafted name would be a stored XSS for every user who lists or searches.
HTML_PAGE = """
<!DOCTYPE html>
<html lang="ar" dir="rtl">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AEFRS - نظام التعرف على الوجوه</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container { max-width: 1400px; margin: 0 auto; }
.header { text-align: center; color: white; margin-bottom: 30px; }
.header h1 { font-size: 2.5em; margin-bottom: 10px; }
.status-bar {
background: rgba(255,255,255,0.2);
border-radius: 10px;
padding: 10px 20px;
margin-bottom: 20px;
text-align: center;
color: white;
}
.mode-production { background: #28a745; color: white; display: inline-block; padding: 3px 10px; border-radius: 20px; font-size: 12px; }
.mode-fallback { background: #ffc107; color: #333; display: inline-block; padding: 3px 10px; border-radius: 20px; font-size: 12px; }
.grid-2 {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(450px, 1fr));
gap: 20px;
}
.card {
background: white;
border-radius: 15px;
padding: 25px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
.card h2 {
color: #667eea;
margin-bottom: 20px;
border-bottom: 2px solid #667eea;
padding-bottom: 10px;
}
.form-group { margin-bottom: 15px; }
label { display: block; margin-bottom: 5px; font-weight: bold; color: #333; }
input, select {
width: 100%;
padding: 12px;
border: 1px solid #ddd;
border-radius: 8px;
font-size: 14px;
}
button {
width: 100%;
padding: 12px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
margin-top: 10px;
}
button:hover { transform: translateY(-2px); }
.result {
margin-top: 15px;
padding: 15px;
border-radius: 8px;
display: none;
}
.result.show { display: block; }
.success { background: #d4edda; color: #155724; border: 1px solid #c3e6cb; }
.error { background: #f8d7da; color: #721c24; border: 1px solid #f5c6cb; }
.info { background: #d1ecf1; color: #0c5460; border: 1px solid #bee5eb; }
.match-item {
padding: 12px;
margin: 8px 0;
background: #f8f9fa;
border-radius: 8px;
border-right: 4px solid #667eea;
}
.match-name { font-size: 16px; font-weight: bold; color: #333; }
.match-id { font-size: 12px; color: #666; margin: 4px 0; }
.match-score { font-size: 14px; font-weight: bold; margin-top: 5px; }
.score-high { color: #28a745; }
.score-medium { color: #ffc107; }
.score-low { color: #dc3545; }
.logout-btn { background: #dc3545; margin-top: 10px; }
.logout-btn:hover { background: #c82333; }
@media (max-width: 768px) { .grid-2 { grid-template-columns: 1fr; } }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🧠 AEFRS</h1>
<p>نظام التعرف على الوجوه المتقدم</p>
</div>
<div class="status-bar" id="statusBar">🔄 جاري الاتصال...</div>
<div class="grid-2">
<div class="card" id="loginCard">
<h2>🔐 تسجيل الدخول</h2>
<div class="form-group">
<label>👤 اسم المستخدم</label>
<input type="text" id="username" placeholder="أي اسم">
</div>
<div class="form-group">
<label>🔑 كلمة المرور</label>
<input type="password" id="password" placeholder="أي كلمة مرور">
</div>
<button onclick="login()">🚀 دخول</button>
<div id="loginResult" class="result"></div>
</div>
<div class="card" id="enrollCard" style="display:none">
<h2>📝 تسجيل وجه جديد</h2>
<div class="form-group">
<label>🆔 المعرف</label>
<input type="text" id="identityId" placeholder="مثال: 001">
</div>
<div class="form-group">
<label>👨 الاسم</label>
<input type="text" id="personName" placeholder="مثال: كريستيانو">
</div>
<div class="form-group">
<label>📸 الصورة</label>
<input type="file" id="enrollImage" accept="image/*">
</div>
<button onclick="enroll()">💾 تسجيل</button>
<div id="enrollResult" class="result"></div>
</div>
<div class="card" id="searchCard" style="display:none">
<h2>🔍 بحث عن وجه</h2>
<div class="form-group">
<label>📸 صورة للبحث</label>
<input type="file" id="searchImage" accept="image/*">
</div>
<div class="form-group">
<label>🔢 عدد النتائج</label>
<select id="topK">
<option value="3">3</option>
<option value="5" selected>5</option>
<option value="10">10</option>
</select>
</div>
<button onclick="search()">🔎 بحث</button>
<div id="searchResult" class="result"></div>
</div>
<div class="card" id="listCard" style="display:none">
<h2>📋 قائمة المسجلين</h2>
<button onclick="listIdentities()">📋 عرض الكل</button>
<button onclick="logout()" class="logout-btn">🚪 تسجيل خروج</button>
<div id="identitiesList" class="result"></div>
</div>
</div>
</div>
<script>
let token = null;
async function checkHealth() {
try {
const res = await fetch('/healthz');
const data = await res.json();
const bar = document.getElementById('statusBar');
if (res.ok) {
const modeText = data.mode === 'production' ? '🎯 وضع الإنتاج (ArcFace يعمل)' : '⚠️ وضع المحاكاة (دقة محدودة)';
const modeClass = data.mode === 'production' ? 'mode-production' : 'mode-fallback';
bar.innerHTML = `✅ النظام يعمل | 👥 ${data.identities_count} شخص | <span class="${modeClass}">${modeText}</span> | 🕐 ${new Date().toLocaleString('ar')}`;
}
} catch(e) {
document.getElementById('statusBar').innerHTML = '❌ خطأ في الاتصال';
}
}
async function login() {
const username = document.getElementById('username').value;
const password = document.getElementById('password').value;
const resultDiv = document.getElementById('loginResult');
if (!username || !password) {
showResult(resultDiv, 'error', '❌ أدخل اسم المستخدم وكلمة المرور');
return;
}
const formData = new FormData();
formData.append('username', username);
formData.append('password', password);
try {
const res = await fetch('/v1/token', { method: 'POST', body: formData });
const data = await res.json();
if (res.ok) {
token = data.access_token;
showResult(resultDiv, 'success', `✅ مرحباً ${data.username}`);
document.getElementById('loginCard').style.display = 'none';
document.getElementById('enrollCard').style.display = 'block';
document.getElementById('searchCard').style.display = 'block';
document.getElementById('listCard').style.display = 'block';
checkHealth();
} else {
showResult(resultDiv, 'error', `❌ ${data.detail}`);
}
} catch(e) {
showResult(resultDiv, 'error', `❌ ${e.message}`);
}
}
async function enroll() {
if (!token) { alert('سجل دخول أولاً'); return; }
const id = document.getElementById('identityId').value;
const name = document.getElementById('personName').value;
const file = document.getElementById('enrollImage').files[0];
const resultDiv = document.getElementById('enrollResult');
if (!id || !name || !file) {
showResult(resultDiv, 'error', '❌ املأ جميع الحقول');
return;
}
const formData = new FormData();
formData.append('identity_id', id);
formData.append('name', name);
formData.append('image', file);
try {
const res = await fetch('/v1/enroll', {
method: 'POST',
headers: { 'Authorization': `Bearer ${token}` },
body: formData
});
const data = await res.json();
if (res.ok && data.status === 'success') {
showResult(resultDiv, 'success', `✅ تم تسجيل ${name}`);
document.getElementById('identityId').value = '';
document.getElementById('personName').value = '';
document.getElementById('enrollImage').value = '';
checkHealth();
} else {
showResult(resultDiv, 'error', `❌ ${data.message || 'خطأ'}`);
}
} catch(e) {
showResult(resultDiv, 'error', `❌ ${e.message}`);
}
}
async function search() {
if (!token) { alert('سجل دخول أولاً'); return; }
const file = document.getElementById('searchImage').files[0];
const topK = document.getElementById('topK').value;
const resultDiv = document.getElementById('searchResult');
if (!file) {
showResult(resultDiv, 'error', '❌ اختر صورة');
return;
}
const formData = new FormData();
formData.append('image', file);
formData.append('top_k', topK);
try {
const res = await fetch('/v1/search', {
method: 'POST',
headers: { 'Authorization': `Bearer ${token}` },
body: formData
});
const data = await res.json();
if (res.ok && data.status === 'success') {
if (data.matches.length === 0) {
showResult(resultDiv, 'info', '⚠️ لا توجد نتائج');
} else {
let html = '<h3>🔍 النتائج:</h3>';
data.matches.forEach(m => {
const percent = m.similarity_percent;
let scoreClass = percent >= 70 ? 'score-high' : (percent >= 50 ? 'score-medium' : 'score-low');
html += `
<div class="match-item">
<div class="match-name">👤 ${escapeHtml(m.name)}</div>
<div class="match-id">🆔 ${escapeHtml(m.identity_id)}</div>
<div class="match-score ${scoreClass}">🎯 نسبة التشابه: ${percent}%</div>
</div>
`;
});
resultDiv.innerHTML = html;
resultDiv.className = 'result show success';
}
} else {
showResult(resultDiv, 'error', `❌ ${data.message || 'خطأ'}`);
}
} catch(e) {
showResult(resultDiv, 'error', `❌ ${e.message}`);
}
}
async function listIdentities() {
if (!token) { alert('سجل دخول أولاً'); return; }
const resultDiv = document.getElementById('identitiesList');
try {
const res = await fetch('/v1/identities', {
headers: { 'Authorization': `Bearer ${token}` }
});
const data = await res.json();
if (res.ok) {
if (data.total === 0) {
showResult(resultDiv, 'info', '⚠️ لا يوجد مسجلين');
} else {
let html = `<h3>📊 المجموع: ${data.total}</h3>`;
data.identities.forEach(i => {
html += `
<div class="match-item">
<div class="match-name">👤 ${escapeHtml(i.name)}</div>
<div class="match-id">🆔 ${escapeHtml(i.identity_id)}</div>
<div class="match-score">📅 ${new Date(i.created_at).toLocaleDateString('ar')}</div>
</div>
`;
});
resultDiv.innerHTML = html;
resultDiv.className = 'result show success';
}
} else {
showResult(resultDiv, 'error', `❌ ${data.detail || 'خطأ'}`);
}
} catch(e) {
showResult(resultDiv, 'error', `❌ ${e.message}`);
}
}
async function logout() {
if (token) {
try {
await fetch('/v1/logout', {
method: 'POST',
headers: { 'Authorization': `Bearer ${token}` }
});
} catch(e) {}
}
token = null;
document.getElementById('loginCard').style.display = 'block';
document.getElementById('enrollCard').style.display = 'none';
document.getElementById('searchCard').style.display = 'none';
document.getElementById('listCard').style.display = 'none';
document.getElementById('username').value = '';
document.getElementById('password').value = '';
showResult(document.getElementById('loginResult'), 'info', '👋 تم الخروج');
}
function showResult(el, type, msg) {
el.textContent = msg;
el.className = `result show ${type}`;
setTimeout(() => { if (el.textContent === msg) el.classList.remove('show'); }, 5000);
}
function escapeHtml(value) {
return String(value).replace(/[&<>"']/g, ch => ({'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'})[ch]);
}
checkHealth();
setInterval(checkHealth, 10000);
</script>
</body>
</html>
"""
@app.get("/", response_class=HTMLResponse)
async def ui():
    """Serve the single-page web client (the HTML_PAGE constant)."""
    page = HTML_PAGE
    return HTMLResponse(page)
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed as a script.
    import uvicorn

    host = "0.0.0.0"
    port = int(os.environ.get("PORT", 7860))
    model_mode = 'PRODUCTION' if embedding_session else 'FALLBACK (HOG)'
    print(f"🚀 Server on http://{host}:{port}")
    print(f"📊 ArcFace mode: {model_mode}")
    uvicorn.run(app, host=host, port=port)