issoufzousko07 commited on
Commit
76d2f94
·
0 Parent(s):

Final V2.2: Persistent Storage & Fixes

Browse files
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ models/
2
+ data_storage/
3
+ elephmind.db
4
+ .env
5
+ venv/
6
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Hugging Face Spaces Docker Configuration
FROM python:3.10-slim

# Create non-root user (required by HuggingFace)
RUN useradd -m -u 1000 user
ENV HOME=/home/user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install system dependencies as root
# (libgl1 / libglib2.0-0 / libsm6 / libxext6 / libxrender1 are the usual
# runtime libraries needed by OpenCV's cv2 import on slim images)
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgl1 \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender1 \
    && rm -rf /var/lib/apt/lists/*

# Create directories as root BEFORE switching user
RUN mkdir -p /app/storage/uploads /app/storage/processed && \
    chown -R user:user /app

# Switch to non-root user
USER user

# Copy requirements and install
# requirements.txt is copied separately so Docker layer caching skips the
# pip install step when only application code changes.
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt

# Copy the rest of the application
COPY --chown=user . /app

# Expose port 7860 (required by Hugging Face Spaces)
EXPOSE 7860

# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ElephMind Medical AI
3
+ emoji: 🏥
4
+ colorFrom: green
5
+ colorTo: green
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: true
9
+ license: apache-2.0
10
+ ---
11
+
12
+ # ElephMind - Diagnostic IA Médical
13
+
14
+ Application d'aide au diagnostic médical basée sur l'intelligence artificielle.
15
+
16
+ ## Fonctionnalités
17
+ - Analyse de radiographies thoraciques
18
+ - Analyse dermatologique
19
+ - Analyse histologique
20
+ - Analyse ophtalmologique
21
+ - Analyse orthopédique
database.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import sqlite3
import os
import logging
from typing import Optional, List, Dict, Any

# Directory containing this file; fallback location for the SQLite database.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# HUGGING FACE PERSISTENCE FIX: Use /data if available
# On Spaces, /data is the only path that survives container restarts, so the
# database is placed there whenever that mount exists.
if os.path.exists('/data'):
    DB_NAME = '/data/elephmind.db'
    logging.info("Using PERSISTENT storage at /data/elephmind.db")
else:
    # Local development: keep the database file next to the source code.
    DB_NAME = os.path.join(BASE_DIR, "elephmind.db")
    logging.info(f"Using LOCAL storage at {DB_NAME}")
14
+
def get_db_connection():
    """Open a connection to the application database.

    Rows are returned as sqlite3.Row so callers can access columns by name
    and convert them with dict().
    """
    connection = sqlite3.connect(DB_NAME)
    connection.row_factory = sqlite3.Row
    return connection
19
+
def init_db():
    """Create the users, feedback and audit_log tables if they are missing."""
    connection = get_db_connection()
    cursor = connection.cursor()

    # Account table: one row per registered user.
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            hashed_password TEXT NOT NULL,
            email TEXT,
            security_question TEXT NOT NULL,
            security_answer TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    ''')

    # User feedback: rating plus free-text comment.
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT,
            rating INTEGER,
            comment TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    ''')

    # Action trail kept for RGPD compliance / security auditing.
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS audit_log (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT,
            action TEXT NOT NULL,
            resource TEXT,
            ip_address TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    ''')

    connection.commit()
    connection.close()
    logging.info(f"Database {DB_NAME} initialized successfully.")
63
+
# --- User Operations ---

def create_user(user: Dict[str, Any]) -> bool:
    """Insert a new user row.

    Args:
        user: Mapping with keys 'username', 'hashed_password',
            'security_question', 'security_answer' and optional 'email'.

    Returns:
        True on success; False when the username already exists
        (UNIQUE constraint) or on any other database error.
    """
    conn = None  # FIX: ensure the name exists for the finally clause even
    # when get_db_connection() itself raises (the original hit NameError).
    try:
        conn = get_db_connection()
        c = conn.cursor()
        c.execute('''
            INSERT INTO users (username, hashed_password, email, security_question, security_answer)
            VALUES (?, ?, ?, ?, ?)
        ''', (
            user['username'],
            user['hashed_password'],
            user.get('email', ''),
            user['security_question'],
            user['security_answer']
        ))
        conn.commit()
        return True
    except sqlite3.IntegrityError:
        # UNIQUE constraint on username: caller treats this as "already exists".
        return False
    except Exception as e:
        logging.error(f"Error creating user: {e}")
        return False
    finally:
        if conn is not None:
            conn.close()
89
+
def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
    """Fetch one user row by username; returns None when no match exists."""
    connection = get_db_connection()
    cursor = connection.cursor()
    cursor.execute('SELECT * FROM users WHERE username = ?', (username,))
    record = cursor.fetchone()
    connection.close()
    return dict(record) if record else None
99
+
def update_password(username: str, new_hashed_password: str) -> bool:
    """Replace a user's password hash.

    Returns:
        True when the UPDATE executed (even if no row matched the username),
        False on any database error.
    """
    conn = None
    try:
        conn = get_db_connection()
        c = conn.cursor()
        c.execute('UPDATE users SET hashed_password = ? WHERE username = ?',
                  (new_hashed_password, username))
        conn.commit()
        return True
    except Exception as e:
        logging.error(f"Error updating password: {e}")
        return False
    finally:
        # FIX: the original leaked the connection when execute/commit raised.
        if conn is not None:
            conn.close()
111
+
# --- Feedback Operations ---

def add_feedback(username: str, rating: int, comment: str):
    """Persist a single feedback entry (rating plus free-text comment)."""
    connection = get_db_connection()
    connection.cursor().execute(
        'INSERT INTO feedback (username, rating, comment) VALUES (?, ?, ?)',
        (username, rating, comment),
    )
    connection.commit()
    connection.close()
120
+
# --- Audit Log Operations (RGPD Compliance) ---

def log_audit(username: str, action: str, resource: Optional[str] = None, ip_address: Optional[str] = None):
    """Log user actions for RGPD compliance and security auditing.

    Failures are swallowed deliberately: auditing must never break the
    request that triggered it.
    """
    conn = None
    try:
        conn = get_db_connection()
        c = conn.cursor()
        c.execute(
            'INSERT INTO audit_log (username, action, resource, ip_address) VALUES (?, ?, ?, ?)',
            (username, action, resource, ip_address)
        )
        conn.commit()
    except Exception as e:
        logging.error(f"Error logging audit: {e}")
    finally:
        # FIX: the original leaked the connection when the INSERT raised.
        if conn is not None:
            conn.close()
136
+
def get_user_audit_log(username: str, limit: int = 100) -> List[Dict[str, Any]]:
    """Return up to `limit` audit entries for `username`, newest first."""
    connection = get_db_connection()
    cursor = connection.cursor()
    cursor.execute(
        'SELECT * FROM audit_log WHERE username = ? ORDER BY created_at DESC LIMIT ?',
        (username, limit)
    )
    entries = [dict(record) for record in cursor.fetchall()]
    connection.close()
    return entries
148
+
# --- Analysis Registry (REAL DATA ONLY) ---

def init_analysis_registry():
    """Create the analysis_registry table if it doesn't already exist."""
    connection = get_db_connection()
    connection.cursor().execute('''
        CREATE TABLE IF NOT EXISTS analysis_registry (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT NOT NULL,
            domain TEXT NOT NULL,
            top_diagnosis TEXT,
            confidence REAL,
            priority TEXT,
            computation_time_ms INTEGER,
            file_type TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    connection.commit()
    connection.close()
170
+
def log_analysis(
    username: str,
    domain: str,
    top_diagnosis: str,
    confidence: float,
    priority: str,
    computation_time_ms: int,
    file_type: str
) -> bool:
    """Log a real analysis to the registry. NO FAKE DATA.

    Returns:
        True when the row was inserted, False on any database error.
    """
    conn = None
    try:
        conn = get_db_connection()
        c = conn.cursor()
        c.execute('''
            INSERT INTO analysis_registry
            (username, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        ''', (username, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type))
        conn.commit()
        return True
    except Exception as e:
        logging.error(f"Error logging analysis: {e}")
        return False
    finally:
        # FIX: the original leaked the connection when execute/commit raised.
        if conn is not None:
            conn.close()
195
+
def get_dashboard_stats(username: str) -> Dict[str, Any]:
    """Aggregate real dashboard statistics for one user (zeros when empty)."""
    connection = get_db_connection()
    cursor = connection.cursor()

    # Overall number of analyses run by this user.
    cursor.execute('SELECT COUNT(*) FROM analysis_registry WHERE username = ?', (username,))
    total_count = cursor.fetchone()[0]

    # Breakdown per medical domain.
    cursor.execute('''
        SELECT domain, COUNT(*) as count
        FROM analysis_registry
        WHERE username = ?
        GROUP BY domain
    ''', (username,))
    domain_counts = {record['domain']: record['count'] for record in cursor.fetchall()}

    # Breakdown per triage priority.
    cursor.execute('''
        SELECT priority, COUNT(*) as count
        FROM analysis_registry
        WHERE username = ?
        GROUP BY priority
    ''', (username,))
    priority_counts = {record['priority']: record['count'] for record in cursor.fetchall()}

    # Mean processing time; AVG() yields NULL when there are no rows.
    cursor.execute('''
        SELECT AVG(computation_time_ms)
        FROM analysis_registry
        WHERE username = ?
    ''', (username,))
    mean_time = cursor.fetchone()[0] or 0

    connection.close()

    return {
        "total_analyses": total_count,
        "by_domain": domain_counts,
        "by_priority": priority_counts,
        "avg_computation_time_ms": round(mean_time, 0)
    }
239
+
def get_recent_analyses(username: str, limit: int = 10) -> List[Dict[str, Any]]:
    """Return the user's newest analyses; empty list when none exist."""
    connection = get_db_connection()
    cursor = connection.cursor()
    cursor.execute('''
        SELECT id, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type, created_at
        FROM analysis_registry
        WHERE username = ?
        ORDER BY created_at DESC
        LIMIT ?
    ''', (username, limit))
    records = cursor.fetchall()
    connection.close()
    return [dict(record) for record in records]
254
+
255
+
encryption.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from cryptography.fernet import Fernet
import os
import sys
import logging
from typing import Optional

# -------------------------------------------------------------------------
# ENCRYPTION CONFIGURATION - PRODUCTION READY
# -------------------------------------------------------------------------

# Environment detection ("production" enables fail-fast key checks below).
ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
IS_PRODUCTION = ENVIRONMENT == "production"

# Encryption Key - Load from environment variable
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY")

if not ENCRYPTION_KEY:
    if IS_PRODUCTION:
        logging.critical("🔴 FATAL ERROR: ENCRYPTION_KEY must be set in production environment")
        logging.critical("Generate one with: python -c 'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'")
        sys.exit(1)  # Fail-fast in production
    else:
        # Development fallback with ephemeral key
        # NOTE: data encrypted before a restart becomes unreadable afterwards,
        # since a new key is generated on every process start.
        ENCRYPTION_KEY = Fernet.generate_key().decode()
        logging.warning("⚠️ WARNING: Using ephemeral encryption key (development only)")

# Initialize cipher (Fernet expects the key as bytes, so strings are encoded).
cipher_suite = Fernet(ENCRYPTION_KEY.encode() if isinstance(ENCRYPTION_KEY, str) else ENCRYPTION_KEY)
30
+
def encrypt_data(data: str) -> str:
    """Encrypt a UTF-8 string with the module cipher.

    Returns the Fernet token as text, or "" when the input is empty/falsy.
    """
    if not data:
        return ""
    token = cipher_suite.encrypt(data.encode('utf-8'))
    return token.decode('utf-8')
38
+
def decrypt_data(token: str) -> Optional[str]:
    """Decrypt a Fernet token back to the original string.

    Returns:
        The plaintext, or None when the token is empty, invalid, corrupted,
        or was produced with a different key.
    """
    if not token:
        return None
    try:
        decrypted_bytes = cipher_suite.decrypt(token.encode('utf-8'))
        return decrypted_bytes.decode('utf-8')
    except Exception as e:
        # FIX: use logging (consistent with the rest of this module) instead
        # of print(), so failures reach the configured log handlers.
        logging.error(f"Decryption failed: {e}")
        return None
50
+
def rotate_key() -> bytes:
    """Rotate the in-memory encryption key.

    The original implementation referenced the undefined names
    ENCRYPTION_KEY_PATH and `key`, so calling it always raised NameError.
    This version generates a fresh Fernet key, swaps the module-level
    cipher, and returns the new key so the caller can persist it (e.g.
    into the ENCRYPTION_KEY environment variable / secret store).

    WARNING: data encrypted with the previous key becomes unreadable
    unless it is re-encrypted before rotating.
    """
    global ENCRYPTION_KEY, cipher_suite
    new_key = Fernet.generate_key()
    ENCRYPTION_KEY = new_key.decode()
    cipher_suite = Fernet(new_key)
    logging.warning("Encryption key rotated in memory; persist it to ENCRYPTION_KEY to survive restarts.")
    return new_key
61
+
if __name__ == "__main__":
    # Round-trip smoke test of the encrypt/decrypt pair.
    sample = "Jean Dupont - Patient Zero"
    token = encrypt_data(sample)
    restored = decrypt_data(token)

    print(f"Original: {sample}")
    print(f"Encrypted: {token}")
    print(f"Decrypted: {restored}")
    assert sample == restored
localization.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Localization mappings (English -> French).
# This file translates the AI's raw results without touching the original
# prompts, which must stay in English for model performance.

# Per-domain display name and description shown in the French UI.
# Keys are the English domain labels produced by the model.
DOMAIN_TRANSLATIONS = {
    'Thoracic': {
        'label': 'Thoracique',
        'description': 'Analyse Radiographique du Thorax'
    },
    'Dermatology': {
        'label': 'Dermatologie',
        'description': 'Analyse Dermatoscope des Lésions Cutanées'
    },
    'Histology': {
        'label': 'Histologie',
        'description': 'Analyse Microscopique (H&E)'
    },
    'Ophthalmology': {
        'label': 'Ophtalmologie',
        'description': 'Fond d\'Oeil (Rétine)'
    },
    'Orthopedics': {
        'label': 'Orthopédie',
        'description': 'Radiographie Osseuse'
    }
}
27
+
# Exact-match translations for the model's English finding labels.
# Keys must match the English prompt strings character-for-character;
# unmatched labels are shown in English (see localize_result's fallback).
LABEL_TRANSLATIONS = {
    # --- THORACIC ---
    'Diffuse interstitial opacities or ground-glass pattern (Viral/Atypical Pneumonia)':
        'Opacités interstitielles diffuses ou aspect en verre dépoli (Pneumonie Virale/Atypique)',

    'Focal alveolar consolidation with air bronchograms (Bacterial Pneumonia)':
        'Condensation alvéolaire focale avec bronchogrammes aériens (Pneumonie Bactérienne)',

    'Perfectly clear lungs, sharp costophrenic angles, no pathology':
        'Poumons parfaitement clairs, angles costophréniques nets, aucune pathologie',

    'Pneumothorax (Lung collapse)': 'Pneumothorax (Décollement de la plèvre)',
    'Pleural Effusion (Fluid)': 'Épanchement Pleural (Liquide)',
    'Cardiomegaly (Enlarged heart)': 'Cardiomégalie (Cœur élargi)',
    'Pulmonary Edema': 'Œdème Pulmonaire',
    'Lung Nodule or Mass': 'Nodule ou Masse Pulmonaire',
    'Atelectasis (Lung collapse)': 'Atélectasie (Affaissement pulmonaire)',

    # --- DERMATOLOGY ---
    'A healthy skin area without lesion': 'Zone de peau saine sans lésion',
    'A benign nevus (mole) regular, symmetrical and homogeneous': 'Nævus bénin (grain de beauté) régulier, symétrique et homogène',
    'A seborrheic keratosis (benign warty lesion)': 'Kératose séborrhéique (lésion verruqueuse bénigne)',
    'A malignant melanoma with asymmetry, irregular borders and multiple colors': 'Mélanome malin (Asymétrie, Bords irréguliers, Couleurs multiples)',
    'A basal cell carcinoma (pearly or ulcerated lesion)': 'Carcinome basocellulaire (lésion perlée ou ulcérée)',
    'A squamous cell carcinoma (crusty or budding lesion)': 'Carcinome épidermoïde (lésion croûteuse ou bourgeonnante)',
    'A non-specific inflammatory skin lesion': 'Lésion cutanée inflammatoire non spécifique',

    # --- ORTHOPEDICS ---
    'Severe osteoarthritis with bone-on-bone contact and large osteophytes (Grade 4)': 'Arthrose sévère avec contact os-contre-os et ostéophytes importants (Grade 4)',
    'Moderate osteoarthritis with definite joint space narrowing (Grade 2-3)': 'Arthrose modérée avec pincement articulaire net (Grade 2-3)',
    'Normal knee joint with preserved joint space and no osteophytes (Grade 0-1)': 'Genou normal, interligne articulaire préservé (Grade 0-1)',
    'Total knee arthroplasty (TKA) with metallic implant': 'Prothèse totale de genou (implant métallique)',
    'Acute knee fracture or dislocation': 'Fracture ou luxation aiguë du genou',
    'Other x-ray view (Chest, Hand, Foot, Pediatric) - OUT OF DISTRIBUTION': 'Autre vue radiographique (Hors périmètre)',
    'A knee x-ray view (Knee Joint)': 'Radiographie du Genou'
}
64
+
def localize_result(result_json):
    """Translate raw English AI results to French, in place.

    Uses DOMAIN_TRANSLATIONS / LABEL_TRANSLATIONS; any label with no
    translation is kept in English as a fallback. Results missing the
    'domain' or 'specific' keys are tolerated (FIX: the original raised
    KeyError on them).

    Returns:
        The same dict, mutated with the translated labels.
    """
    # 1. Localize the domain block, if present.
    domain = result_json.get('domain') or {}
    domain_key = domain.get('label')
    if domain_key in DOMAIN_TRANSLATIONS:
        translation = DOMAIN_TRANSLATIONS[domain_key]
        domain['label'] = translation['label']
        domain['description'] = translation['description']

    # 2. Localize each specific finding; untranslated labels stay English.
    for item in result_json.get('specific', []):
        translated = LABEL_TRANSLATIONS.get(item['label'])
        if translated is not None:
            item['label'] = translated

    return result_json
main.py ADDED
@@ -0,0 +1,1478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ElephMind Medical AI Backend
3
+ ============================
4
+ Production-ready FastAPI backend for medical image analysis using SigLIP.
5
+
6
+ Author: ElephMind Team
7
+ Version: 2.0.0 (Cleaned & Secured)
8
+ """
9
+
10
+ import sys
11
+ import os
12
+ import uuid
13
+ import asyncio
14
+ import time
15
+ import logging
16
+
17
+ # --- DOTENV SUPPORT (MUST BE FIRST) ---
18
+ try:
19
+ from dotenv import load_dotenv
20
+ load_dotenv()
21
+ except ImportError:
22
+ pass
23
+ from enum import Enum
24
+ from typing import Dict, List, Optional, Any, Tuple
25
+ from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
26
+ from fastapi.middleware.cors import CORSMiddleware
27
+ from pydantic import BaseModel
28
+ import uvicorn
29
+ from contextlib import asynccontextmanager
30
+ import base64
31
+ import cv2
32
+ import numpy as np
33
+ from pytorch_grad_cam import GradCAMPlusPlus
34
+ from pytorch_grad_cam.utils.image import show_cam_on_image
35
+ from localization import localize_result
36
+ import torch
37
+ import torch.nn as nn
38
+ from storage import get_storage_provider
39
+ import encryption
40
+ import database
41
+ # algorithms imported directly above
42
+
43
+ import math
44
+ from collections import deque
45
+ from dataclasses import dataclass, field
46
+ from PIL import Image
47
+ import io
48
+
# --- GRADCAM UTILS FOR SIGLIP/ViT ---
class HuggingFaceWeirdCLIPWrapper(torch.nn.Module):
    """Adapter exposing a HF CLIP/SigLIP-style model as image -> logits.

    pytorch_grad_cam expects a module whose forward takes only the image
    tensor, so the tokenized text inputs are captured at construction time
    and replayed on every forward pass.
    """

    def __init__(self, model, input_ids, attention_mask):
        super(HuggingFaceWeirdCLIPWrapper, self).__init__()
        self.model = model
        # Pre-tokenized text prompts reused for every image forward pass.
        self.input_ids = input_ids
        self.attention_mask = attention_mask

    def forward(self, input_tensor):
        # input_tensor is pixel_values
        return self.model(
            input_ids=self.input_ids,
            pixel_values=input_tensor,
            attention_mask=self.attention_mask
        ).logits_per_image
64
+
def reshape_transform(tensor, height=14, width=14):
    """Reshape ViT token activations [B, T, C] into a CAM-friendly
    [B, C, G, G] feature map, dropping the first token.

    `height`/`width` are kept for interface compatibility; the square grid
    size G is actually inferred from the remaining sequence length
    (assumes batch size 1 semantics, as in the original heuristic).
    """
    # Drop the first token (index 0), keep the patch tokens.
    patch_tokens = tensor[:, 1:, :]

    # Infer the square grid side from the patch count.
    side = int(math.sqrt(patch_tokens.size(1)))

    # [B, T-1, C] -> [B, G, G, C]
    grid = patch_tokens.reshape(tensor.size(0), side, side, tensor.size(2))

    # [B, G, G, C] -> [B, C, G, G] as expected by GradCAM.
    return grid.transpose(2, 3).transpose(1, 2)
82
+
83
+ # --- AUTH IMPORTS ---
84
+ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
85
+ from fastapi import Depends, status, Request
86
+ from datetime import datetime, timedelta
87
+ from jose import JWTError, jwt
88
+ import bcrypt
89
+
90
+ # --- DOTENV (Moved to top) ---
91
+
92
+ # =========================================================================
93
+ # LOGGING CONFIGURATION
94
+ # =========================================================================
95
+ logging.basicConfig(
96
+ level=logging.INFO,
97
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
98
+ handlers=[logging.StreamHandler(sys.stdout)]
99
+ )
100
+ logger = logging.getLogger("ElephMind-Backend")
101
+
102
+ # =========================================================================
103
+ # 7 INTELLIGENCE ALGORITHMS (Merged from algorithms.py)
104
+ # =========================================================================
105
+
# 1. IMAGE QUALITY ASSESSMENT
def detect_blur(image: np.ndarray) -> float:
    """Estimate sharpness via the variance of the Laplacian.

    Returns a score in [0, 1]: 0.0 = very blurry, 1.0 = very sharp.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    variance = cv2.Laplacian(gray, cv2.CV_64F).var()
    # 500 is an empirical saturation point for medical images.
    return min(1.0, variance / 500.0)
121
+
def assess_image_quality(image: np.ndarray) -> Dict[str, Any]:
    """Score image quality (0-100) from sharpness, contrast and resolution."""
    total = 0
    measurements = []

    # Sharpness (Laplacian variance via detect_blur), worth up to 40 points.
    sharpness = detect_blur(image)
    measurements.append({"metric": "Netteté", "value": int(sharpness * 100)})
    if sharpness > 0.6:
        total += 40
    elif sharpness > 0.3:
        total += 20

    # Contrast (std-dev of grayscale intensities), worth 30 points.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    contrast = float(gray.std())
    measurements.append({"metric": "Contraste", "value": int(min(100, contrast * 2))})
    if contrast > 40:
        total += 30

    # Resolution (pixel count relative to 1 megapixel), worth 30 points.
    h, w = image.shape[:2]
    measurements.append({"metric": "Résolution", "value": int(min(100, (h * w) / (1024 * 1024) * 100))})
    if h * w > 512 * 512:
        total += 30

    return {
        "quality_score": min(100, total),
        "metrics": measurements
    }
153
+
# 2. CONFIDENCE CALIBRATION
def calibrate_confidence(raw_stats: List[float], labels: List[str]) -> float:
    """Map raw model scores to a calibrated percentage in [0, 99].

    `labels` is accepted for interface compatibility but is unused here.
    Returns 0.0 for empty input.
    """
    if not raw_stats:
        return 0.0

    # Take the strongest prediction, apply a mild boost, cap below 100%.
    boosted = max(raw_stats) * 1.1
    calibrated = min(0.99, boosted)
    return float(round(calibrated * 100, 2))
169
+
# 3. CLINICAL PRIORITY SCORING
def calculate_priority_score(predictions: List[Dict], domain: str) -> str:
    """Map the top prediction to a French triage level.

    Returns "Élevée", "Moyenne" or "Normale" (the default when there are
    no predictions or nothing severe is found). `domain` is currently
    unused but kept for interface compatibility.
    """
    if not predictions:
        return "Normale"

    leading = predictions[0]
    text = leading["label"].lower()
    score = leading["probability"]

    # Keyword lists ordered by severity; matched as substrings of the label.
    critical_terms = ["malignant", "cancer", "carcinoma", "pneumonia", "pneumothorax", "fracture", "grade 4"]
    warning_terms = ["grade 2", "grade 3", "effusion", "edema", "abnormal"]

    if score > 50 and any(term in text for term in critical_terms):
        return "Élevée"
    if score > 40 and any(term in text for term in warning_terms):
        return "Moyenne"
    return "Normale"
192
+
# 4. AUTOMATIC REPORT GENERATION
def generate_clinical_report(analysis_result: Dict[str, Any], patient_info: Optional[Dict] = None) -> str:
    """Build a deterministic French text report from an analysis result.

    Template-based (no LLM): header, top finding, triage priority, then up
    to three secondary findings. Returns a short failure message when no
    specific findings are present.
    """
    domain = analysis_result.get("domain", {}).get("label", "Unknown")
    specifics = analysis_result.get("specific", [])

    if not specifics:
        return "Analyse non concluante."

    top_finding = specifics[0]
    priority = analysis_result.get("priority", "Normale")

    parts = [f"RAPPORT D'ANALYSE AUTOMATISÉE - {domain.upper()}\n"]
    parts.append(f"Date: {datetime.now().strftime('%d/%m/%Y %H:%M')}\n")
    if patient_info:
        parts.append(f"Patient ID: {patient_info.get('id', 'N/A')}\n")
    parts.append("-" * 40 + "\n")

    parts.append(f"Observation Principale: {top_finding['label']}\n")
    parts.append(f"Confiance IA: {top_finding['probability']}%\n")
    parts.append(f"Priorité de Triage: {priority.upper()}\n\n")

    parts.append("Détails Techniques:\n")
    for secondary in specifics[1:4]:
        parts.append(f"- {secondary['label']}: {secondary['probability']}%\n")

    return "".join(parts)
222
+
# 5. SIMILAR CASE DETECTION (Vector DB Mockup)
@dataclass
class CaseRecord:
    # Immutable snapshot of a past analysis used for similarity search.
    id: str                  # unique case identifier
    embedding: np.ndarray    # image feature vector (1-D)
    diagnosis: str           # top diagnosis label at analysis time
    domain: str              # medical domain label (e.g. 'Thoracic')
    probability: float       # confidence of the stored diagnosis
231
+
class SimilarCaseDatabase:
    """In-memory, FIFO-bounded store of case embeddings with cosine search."""

    def __init__(self):
        self.cases: List[CaseRecord] = []

    def add_case(self, case_id: str, embedding: np.ndarray, diagnosis: str, domain: str, probability: float):
        """Append a case; evict the oldest once 1000 entries are stored."""
        self.cases.append(CaseRecord(case_id, embedding, diagnosis, domain, probability))
        if len(self.cases) > 1000:
            self.cases.pop(0)

    def find_similar(self, query_embedding: np.ndarray, top_k: int = 3, same_domain_only: bool = True, query_domain: str = None) -> List[Dict]:
        """Return up to `top_k` stored cases ranked by cosine similarity.

        Zero-norm embeddings score 0; domain filtering is applied only when
        both `same_domain_only` and `query_domain` are truthy.
        """
        if not self.cases:
            return []

        query_norm = np.linalg.norm(query_embedding)
        ranked = []
        for candidate in self.cases:
            if same_domain_only and query_domain and candidate.domain != query_domain:
                continue

            candidate_norm = np.linalg.norm(candidate.embedding)
            if query_norm > 0 and candidate_norm > 0:
                cosine = np.dot(query_embedding, candidate.embedding) / (query_norm * candidate_norm)
            else:
                cosine = 0
            ranked.append((cosine, candidate))

        ranked.sort(key=lambda pair: pair[0], reverse=True)
        return [
            {
                "case_id": record.id,
                "diagnosis": record.diagnosis,
                "similarity": round(float(score * 100), 1)
            }
            for score, record in ranked[:top_k]
        ]
268
+
# Global instance
# Module-level singleton shared by find_similar_cases / store_case_for_similarity.
similar_case_db = SimilarCaseDatabase()
271
+
def find_similar_cases(embedding: np.ndarray, domain: str, top_k: int = 5) -> Dict[str, Any]:
    """Search the global case store for the most similar prior cases."""
    matches = similar_case_db.find_similar(
        query_embedding=embedding,
        top_k=top_k,
        same_domain_only=True,
        query_domain=domain
    )

    message = f"Trouvé {len(matches)} cas similaires" if matches else "Aucun cas similaire trouvé"
    return {
        "similar_cases": matches,
        "cases_searched": len(similar_case_db.cases),
        "message": message
    }
286
+
def store_case_for_similarity(case_id: str, embedding: np.ndarray, diagnosis: str, domain: str, probability: float):
    """Record an analyzed case in the global store for later similarity lookups."""
    similar_case_db.add_case(case_id, embedding, diagnosis, domain, probability)
296
+
# 6. ADAPTIVE PREPROCESSING
def estimate_noise_level(image: np.ndarray) -> float:
    """Estimate the noise sigma from the Laplacian response.

    Uses the robust median-absolute-deviation estimator.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    response = cv2.Laplacian(gray, cv2.CV_64F)
    # 0.6745 converts a median absolute deviation into a Gaussian sigma.
    return float(np.median(np.abs(response)) / 0.6745)
309
+
def apply_clahe(image: np.ndarray, clip_limit: float = 2.0, grid_size: int = 8) -> np.ndarray:
    """Apply Contrast Limited Adaptive Histogram Equalization.

    For color images, only the lightness channel is equalized (via LAB),
    preserving the original colors.
    """
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(grid_size, grid_size))
    if image.ndim == 3:
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
        lab[:, :, 0] = equalizer.apply(lab[:, :, 0])
        return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
    return equalizer.apply(image)
321
+
def gamma_correction(image: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Apply gamma correction via a 256-entry lookup table.

    Each intensity i is mapped to ((i/255) ** (1/gamma)) * 255; gamma=1.0
    leaves the image unchanged.
    """
    exponent = 1.0 / gamma
    lut = (((np.arange(0, 256) / 255.0) ** exponent) * 255).astype("uint8")
    return cv2.LUT(image, lut)
330
+
def bilateral_denoise(image: np.ndarray, d: int = 9, sigma_color: int = 75, sigma_space: int = 75) -> np.ndarray:
    """Apply bilateral filter for edge-preserving denoising.

    Thin wrapper around cv2.bilateralFilter: `d` is the pixel-neighborhood
    diameter; the sigmas control color mixing and spatial influence.
    """
    return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
334
+
335
def adaptive_preprocessing(image_bytes: bytes) -> Tuple[Image.Image, Dict[str, Any]]:
    """
    Apply intelligent preprocessing based on image analysis.

    The image is first analyzed (histogram spread, mean brightness, noise
    level) and only the corrections that are actually needed are applied.

    Args:
        image_bytes: Raw encoded image bytes (any format cv2.imdecode accepts).

    Returns:
        Tuple of (processed PIL RGB image, log dict describing the analysis
        and every transformation that was applied).

    Raises:
        ValueError: if the bytes cannot be decoded as an image.
    """
    # Decode image
    nparr = np.frombuffer(image_bytes, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)

    if img is None:
        raise ValueError("Could not decode image")

    transformations = []
    original_stats = {
        "mean_brightness": float(np.mean(img)),
        "std_dev": float(np.std(img))
    }

    # Convert to grayscale for analysis
    if len(img.shape) == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        gray = img

    # Analyze histogram: a narrow occupied intensity range means low contrast.
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256]).flatten()
    non_zero = np.where(hist > 0)[0]

    is_low_contrast = bool(len(non_zero) > 0 and (non_zero[-1] - non_zero[0]) < 150)
    is_dark = bool(np.mean(gray) < 60)
    is_bright = bool(np.mean(gray) > 200)
    noise_level = float(estimate_noise_level(gray))

    # Apply adaptive corrections
    processed = img.copy()

    # 1. Low contrast - Apply CLAHE
    if is_low_contrast:
        processed = apply_clahe(processed, clip_limit=2.5)
        transformations.append({
            "type": "CLAHE",
            "reason": "Faible contraste détecté",
            "params": {"clip_limit": 2.5}
        })

    # 2. Dark image - brighten.
    # FIX: gamma_correction computes out = (in/255) ** (1/gamma) * 255, so
    # gamma > 1 brightens. The previous value (0.6) darkened dark images
    # even further.
    if is_dark:
        processed = gamma_correction(processed, gamma=1.6)
        transformations.append({
            "type": "Gamma Correction",
            "reason": "Image trop sombre",
            "params": {"gamma": 1.6}
        })

    # 3. Overexposed - darken (gamma < 1 darkens with this formulation;
    # the previous value 1.6 brightened an already overexposed image).
    if is_bright:
        processed = gamma_correction(processed, gamma=0.6)
        transformations.append({
            "type": "Gamma Correction",
            "reason": "Image surexposée",
            "params": {"gamma": 0.6}
        })

    # 4. Noisy - Bilateral filter (edge-preserving)
    if noise_level > 15:
        processed = bilateral_denoise(processed)
        transformations.append({
            "type": "Bilateral Denoise",
            "reason": f"Bruit détecté (σ={noise_level:.1f})",
            "params": {"d": 9, "sigma": 75}
        })

    # 5. Black level correction for X-rays (crush blacks) — grayscale only
    if len(processed.shape) == 2 or (len(processed.shape) == 3 and processed.shape[2] == 1):
        _, processed = cv2.threshold(processed, 15, 255, cv2.THRESH_TOZERO)
        transformations.append({
            "type": "Black Level Crush",
            "reason": "Correction niveau noir (X-ray)",
            "params": {"threshold": 15}
        })

    # Final normalization to the full 0-255 range
    min_val, max_val = processed.min(), processed.max()
    if max_val > min_val:
        processed = ((processed - min_val) / (max_val - min_val) * 255).astype(np.uint8)
        transformations.append({
            "type": "Normalization",
            "reason": "Normalisation finale",
            "params": {"min": float(min_val), "max": float(max_val)}
        })

    # Convert to PIL Image (RGB)
    if len(processed.shape) == 2:
        pil_image = Image.fromarray(processed).convert("RGB")
    else:
        pil_image = Image.fromarray(cv2.cvtColor(processed, cv2.COLOR_BGR2RGB))

    preprocessing_log = {
        "original_stats": original_stats,
        "analysis": {
            "low_contrast": is_low_contrast,
            "dark": is_dark,
            "bright": is_bright,
            "noise_level": round(noise_level, 2)
        },
        "transformations_applied": transformations,
        "transformation_count": len(transformations)
    }

    return pil_image, preprocessing_log
445
+
446
+ # 7. ENHANCE ANALYSIS RESULT (PIPELINE)
447
def enhance_analysis_result(
    base_result: Dict[str, Any],
    image_array: np.ndarray = None,
    embedding: np.ndarray = None,
    case_id: str = None,
    patient_info: Dict = None
) -> Dict[str, Any]:
    """
    Enhance base analysis result with all 7 algorithms.
    This is the main entry point for the enhanced pipeline.

    Args:
        base_result: Raw inference result (domain/specific/heatmap fields).
        image_array: Decoded image for quality assessment (optional).
        embedding: Image embedding for similar-case retrieval (optional).
        case_id: Identifier under which this case is stored (optional).
        patient_info: Extra patient context for the clinical report (optional).

    Returns:
        A shallow-copied, enriched result dict (nested structures are shared
        with base_result).
    """
    enhanced = base_result.copy()

    # 1. Image Quality (if image provided)
    if image_array is not None:
        enhanced["image_quality"] = assess_image_quality(image_array)

    # Guarded lookup: some results may lack the "specific" key entirely
    # (previously the storage branch below raised KeyError in that case).
    specific = enhanced.get("specific")

    # 2. Confidence Calibration
    if specific:
        raw_probs = [p["probability"] / 100 for p in specific]
        labels = [p["label"] for p in specific]
        enhanced["confidence"] = calibrate_confidence(raw_probs, labels=labels)

    # 3. Priority Scoring
    if specific:
        domain = enhanced.get("domain", {}).get("label", "Unknown")
        enhanced["priority"] = calculate_priority_score(specific, domain)

    # 4. Similar Cases (if embedding provided)
    if embedding is not None and "domain" in enhanced:
        domain = enhanced["domain"].get("label", "Unknown")
        enhanced["similar_cases"] = find_similar_cases(embedding, domain)

        # Store this case for future searches
        if case_id and specific:
            top_pred = specific[0]
            store_case_for_similarity(
                case_id=case_id,
                embedding=embedding,
                diagnosis=top_pred["label"],
                domain=domain,
                probability=top_pred["probability"]
            )

    # 5. Generate Report
    enhanced["report"] = generate_clinical_report(
        enhanced,
        patient_info=patient_info
    )

    return enhanced
498
+
499
# Model directory resolution: prefer the nested "oeil d'elephant" folder when present.
# NOTE(review): MODEL_DIR is reassigned to None further below and resolved at
# startup via get_model_path(); this early value appears to be unused.
BASE_MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
NESTED_DIR = os.path.join(BASE_MODELS_DIR, "oeil d'elephant")
MODEL_DIR = NESTED_DIR if os.path.exists(NESTED_DIR) else BASE_MODELS_DIR

# Environment Detection
ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
IS_PRODUCTION = ENVIRONMENT == "production"

# Security Configuration - JWT Secret Key (ENFORCED in production)
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
if not SECRET_KEY:
    if IS_PRODUCTION:
        logger.critical("🔴 FATAL ERROR: JWT_SECRET_KEY must be set in production environment")
        logger.critical("Generate one with: python -c 'import secrets; print(secrets.token_hex(32))'")
        sys.exit(1)  # Fail-fast in production
    else:
        # Development fallback with warning: a random per-process secret, so
        # tokens do not survive a restart in development.
        from secrets import token_hex
        SECRET_KEY = "dev_insecure_key_" + token_hex(16)
        logger.warning("⚠️ WARNING: Using development JWT secret. DO NOT use in production!")

ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256")
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("JWT_EXPIRE_MINUTES", "60"))

logger.info(f"🌍 Environment: {ENVIRONMENT}")
logger.info(f"✅ JWT SECRET_KEY: {'SET (secure)' if 'dev_insecure' not in SECRET_KEY else 'DEVELOPMENT MODE'}")

# CORS Configuration
# NOTE(review): CORS_ORIGINS is parsed here but the CORSMiddleware below is
# registered with allow_origins=["*"], so this list is currently unused.
CORS_ORIGINS_STR = os.getenv("CORS_ORIGINS", "http://localhost:5173,http://127.0.0.1:5173")
CORS_ORIGINS = [origin.strip() for origin in CORS_ORIGINS_STR.split(",")]

# Concurrency Control: caps simultaneous requests via the semaphore below.
MAX_CONCURRENT_USERS = int(os.getenv("MAX_CONCURRENT_USERS", "200"))
concurrency_semaphore = asyncio.Semaphore(MAX_CONCURRENT_USERS)
533
+
534
+ # =========================================================================
535
+ # MODEL PATH CONFIGURATION (HuggingFace Hub or Local)
536
+ # =========================================================================
537
def get_model_path():
    """Resolve the model directory.

    Resolution order: MODEL_DIR environment variable, local development
    checkout, then a download from the HuggingFace Hub. Raises RuntimeError
    when no source yields a usable path.
    """
    candidates = (
        (os.getenv("MODEL_DIR"), "Using model from environment: {}"),
        (os.path.join(os.path.dirname(__file__), "models", "oeil d'elephant"),
         "Using local model: {}"),
    )
    for candidate, message in candidates:
        if candidate and os.path.exists(candidate):
            logger.info(message.format(candidate))
            return candidate

    # Neither source exists locally: fetch from the HuggingFace Hub.
    try:
        from huggingface_hub import snapshot_download
        logger.info("Downloading model from HuggingFace Hub...")
        hub_path = snapshot_download(
            repo_id="issoufzousko07/medsigclip-model",
            repo_type="model"
        )
        logger.info(f"Model downloaded to: {hub_path}")
        return hub_path
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        raise RuntimeError(f"Model not found locally and failed to download: {e}")
564
+
565
MODEL_DIR = None  # Will be set at startup (resolved by get_model_path() in lifespan)

# OAuth2 Scheme: clients obtain a bearer token from the "token" endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
569
+
570
+ # =========================================================================
571
+ # MEDICAL DOMAINS CONFIGURATION
572
+ # =========================================================================
573
# Zero-shot label configuration consumed by MedSigClipWrapper.predict().
# Each domain provides a 'domain_prompt' used for step-1 domain routing.
# Flat domains list 'specific_labels' scored directly in step 2; Orthopedics
# instead uses a two-stage scheme ('stage_1_triage' gating 'stage_2_diagnosis').
# These strings are fed verbatim to the text encoder — do not edit casually.
MEDICAL_DOMAINS = {
    'Thoracic': {
        'domain_prompt': 'Chest X-Ray Analysis',
        'specific_labels': [
            'Diffuse interstitial opacities or ground-glass pattern (Viral/Atypical Pneumonia)',
            'Focal alveolar consolidation with air bronchograms (Bacterial Pneumonia)',
            'Perfectly clear lungs, sharp costophrenic angles, no pathology',
            'Pneumothorax (Lung collapse)',
            'Pleural Effusion (Fluid)',
            'Cardiomegaly (Enlarged heart)',
            'Pulmonary Edema',
            'Lung Nodule or Mass',
            'Atelectasis (Lung collapse)'
        ]
    },
    'Dermatology': {
        'domain_prompt': 'Dermatoscopic analysis of a pigmented or non-pigmented skin lesion',
        'specific_labels': [
            'A healthy skin area without lesion',
            'A benign nevus (mole) regular, symmetrical and homogeneous',
            'A seborrheic keratosis (benign warty lesion)',
            'A malignant melanoma with asymmetry, irregular borders and multiple colors',
            'A basal cell carcinoma (pearly or ulcerated lesion)',
            'A squamous cell carcinoma (crusty or budding lesion)',
            'A non-specific inflammatory skin lesion'
        ]
    },
    'Histology': {
        'domain_prompt': 'Microscopic analysis of a histological section (H&E stain)',
        'specific_labels': [
            'Healthy breast tissue with preserved lobular architecture',
            'Healthy prostatic tissue with regular glands',
            'Invasive ductal carcinoma of the breast (Disorganized cells)',
            'Prostate adenocarcinoma (Gland fusion)',
            'Cervical dysplasia or intraepithelial neoplasia',
            'Colon cancer tumor tissue',
            'Lung cancer tumor tissue',
            'Adipose tissue (Fat) or connective stroma',
            'Preparation artifact or empty area'
        ]
    },
    'Ophthalmology': {
        'domain_prompt': 'Fundus photography (Retina)',
        'specific_labels': [
            'Normal retina, healthy macula and optic disc',
            'Diabetic retinopathy (hemorrhages, exudates, aneurysms)',
            'Glaucoma (optic disc cupping)',
            'Macular degeneration (drusen or atrophy)'
        ]
    },
    'Orthopedics': {
        'domain_prompt': 'Bone X-Ray (Musculoskeletal)',
        # Stage 1 decides whether the image is a knee X-ray; predict() reads
        # the LAST label's probability to gate stage 2.
        'stage_1_triage': {
            'prompt': 'Anatomical region identification',
            'labels': [
                'Other x-ray view (Chest, Hand, Foot, Pediatric) - OUT OF DISTRIBUTION',
                'A knee x-ray view (Knee Joint)'
            ]
        },
        'stage_2_diagnosis': {
            'prompt': 'Knee Osteoarthritis Severity Assessment',
            'labels': [
                'Severe osteoarthritis with bone-on-bone contact and large osteophytes (Grade 4)',
                'Moderate osteoarthritis with definite joint space narrowing (Grade 2-3)',
                'Normal knee joint with preserved joint space and no osteophytes (Grade 0-1)',
                'Total knee arthroplasty (TKA) with metallic implant',
                'Acute knee fracture or dislocation'
            ]
        }
    }
}
644
+
645
+ # =========================================================================
646
+ # PYDANTIC MODELS
647
+ # =========================================================================
648
class JobStatus(str, Enum):
    """Lifecycle states of an asynchronous analysis job (str-valued for JSON)."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
653
+
654
class Job(BaseModel):
    """An asynchronous analysis job tracked in the in-memory `jobs` registry."""
    id: str  # job identifier used as the registry key
    status: JobStatus
    result: Optional[Dict[str, Any]] = None  # analysis output once completed
    error: Optional[str] = None  # failure reason when status == FAILED
    created_at: float  # creation timestamp (presumably epoch seconds — confirm)
    storage_path: Optional[str] = None  # where the uploaded file was persisted
    encrypted_user: Optional[str] = None
    username: Optional[str] = None  # For registry logging
    file_type: Optional[str] = None  # DICOM, PNG, JPEG
    start_time_ms: Optional[float] = None  # For computation time
665
+
666
class Token(BaseModel):
    """OAuth2 token response body returned by the login endpoint."""
    access_token: str
    token_type: str  # always "bearer" by OAuth2 convention — confirm at call site
669
+
670
class TokenData(BaseModel):
    """Decoded JWT claims of interest (currently only the subject)."""
    username: Optional[str] = None
672
+
673
class User(BaseModel):
    """Public user representation (no credential material)."""
    username: str
    email: Optional[str] = None
676
+
677
class UserInDB(User):
    """User as stored in the database, including hashed credential material."""
    hashed_password: str  # passlib/bcrypt hash, never the plaintext
    security_question: str
    security_answer: str  # stored hashed (see the admin seeding code)
681
+
682
class UserRegister(BaseModel):
    """Registration request payload; password and answer arrive in plaintext."""
    username: str
    password: str
    email: Optional[str] = None
    security_question: str
    security_answer: str
688
+
689
class UserResetPassword(BaseModel):
    """Password-reset request payload, validated against the security answer."""
    username: str
    security_answer: str
    new_password: str
693
+
694
class FeedbackModel(BaseModel):
    """User feedback submission."""
    username: str
    rating: int  # numeric score; valid range not enforced here — confirm at endpoint
    comment: str
698
+
699
+ # =========================================================================
700
+ # GLOBAL STATE
701
+ # =========================================================================
702
# In-memory job registry keyed by job id (a plain dict, so it is lost on restart).
jobs: Dict[str, Job] = {}
# Storage backend selected via the STORAGE_MODE env var (defaults to local disk).
storage_provider = get_storage_provider(os.getenv("STORAGE_MODE", "LOCAL"))

# Initialize Database (also re-run in the lifespan handler at startup)
database.init_db()
707
+
708
# --- SEED DEFAULT USER ---
# Ensure admin user exists for immediate login.
# NOTE(review): hardcoded default credentials ("admin" / "secret") are a
# security risk on any public deployment — change or disable in production.
# bcrypt is used directly here because pwd_context is defined further below;
# the CryptContext accepts bcrypt hashes for verification.
try:
    if not database.get_user_by_username("admin"):
        logging.info("👤 Creating default admin user...")
        # Hash "secret"
        admin_pw = bcrypt.hashpw(b"secret", bcrypt.gensalt()).decode('utf-8')
        security_ans = bcrypt.hashpw(b"admin", bcrypt.gensalt()).decode('utf-8')  # Answer: admin

        database.create_user({
            "username": "admin",
            "hashed_password": admin_pw,
            "email": "admin@elephmind.com",
            "security_question": "Who is the admin?",
            "security_answer": security_ans
        })
        logging.info("✅ Default Admin Created: admin / secret")
except Exception as e:
    # Best-effort seeding: failure is logged but must not block startup.
    logging.error(f"Failed to seed admin user: {e}")
727
+
728
+ # =========================================================================
729
+ # AUTHENTICATION HELPERS
730
+ # =========================================================================
731
from passlib.context import CryptContext

# Password hashing context: the first scheme (argon2) is the default for new
# hashes; existing bcrypt hashes (e.g. the seeded admin) remain verifiable and
# are flagged deprecated so they can be re-hashed on login.
pwd_context = CryptContext(schemes=["argon2", "bcrypt"], deprecated="auto")
734
+
735
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored passlib hash."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
738
+
739
def get_password_hash(password: str) -> str:
    """Hash *password* with the configured passlib context (argon2 by default)."""
    digest = pwd_context.hash(password)
    return digest
742
+
743
def get_user(db, username: str) -> Optional[UserInDB]:
    """Look up *username* in the database.

    The *db* argument is unused (kept for API compatibility with callers).
    Returns None when no such user exists.
    """
    record = database.get_user_by_username(username)
    return UserInDB(**record) if record else None
749
+
750
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed (e.g. {"sub": username}).
        expires_delta: Token lifetime; defaults to 15 minutes.

    Returns:
        The encoded JWT string.
    """
    from datetime import timezone  # local import: module header is outside this block

    to_encode = data.copy()
    # Timezone-aware timestamp; datetime.utcnow() is deprecated and naive.
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
756
+
757
async def get_current_user(token: str = Depends(oauth2_scheme)) -> UserInDB:
    """Resolve the bearer token to a database user, or raise HTTP 401."""
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject = payload.get("sub")
        if subject is None:
            # Raised inside the try, but HTTPException is not a JWTError,
            # so it propagates past the except below.
            raise credentials_exception
        token_data = TokenData(username=subject)
    except JWTError:
        raise credentials_exception

    user = get_user(None, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
777
+
778
+ # =========================================================================
779
+ # GRAD-CAM UTILITIES
780
+ # =========================================================================
781
class HuggingFaceWeirdCLIPWrapper(nn.Module):
    """Adapts a SigLIP/CLIP-style model to a plain image-classifier interface.

    Grad-CAM expects ``forward(pixel_values) -> logits``; this wrapper freezes
    the text inputs at construction time and returns the image-vs-text logits.
    """

    def __init__(self, model, text_input_ids, attention_mask):
        super().__init__()
        self.model = model
        self.text_input_ids = text_input_ids
        self.attention_mask = attention_mask

    def forward(self, pixel_values):
        call_kwargs = {
            "pixel_values": pixel_values,
            "input_ids": self.text_input_ids,
            "attention_mask": self.attention_mask,
        }
        return self.model(**call_kwargs).logits_per_image
797
+
798
def reshape_transform(tensor, width=32, height=32):
    """Reshape transformer token embeddings (B, N, C) into a CAM-friendly (B, C, H, W) map.

    The spatial side length is inferred from the token count; *width* and
    *height* are accepted for API compatibility but are not used.
    """
    batch, num_tokens, channels = tensor.size(0), tensor.size(1), tensor.size(2)
    side = int(num_tokens ** 0.5)
    grid = tensor.reshape(batch, side, side, channels)
    # (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)
805
+
806
+ # =========================================================================
807
+ # MODEL WRAPPER
808
+ # =========================================================================
809
class MedSigClipWrapper:
    """Wrapper for the SigLIP model with medical domain inference.

    Loads a locally stored SigLIP checkpoint and runs a two-step zero-shot
    pipeline: (1) identify the medical domain from MEDICAL_DOMAINS prompts,
    (2) score the domain-specific labels. Also produces a Grad-CAM++ heatmap
    for the top finding and applies the post-processing pipeline via
    enhance_analysis_result().
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self.processor = None    # transformers AutoProcessor, set by load()
        self.model = None        # transformers AutoModel, set by load()
        self.loaded = False      # True once load() succeeded
        self.load_error = None   # human-readable reason when load() failed

    def load(self):
        """Load the SigLIP model from the specified directory.

        Never raises: failures are recorded in ``self.load_error`` and
        surfaced later by ``predict()``.
        """
        logger.info(f"Initiating model load from: {self.model_path}")

        if not os.path.exists(self.model_path):
            self.load_error = f"Model directory not found: {self.model_path}"
            logger.critical(self.load_error)
            return

        try:
            from transformers import AutoProcessor, AutoModel
            import torch

            self.processor = AutoProcessor.from_pretrained(self.model_path, local_files_only=True)
            self.model = AutoModel.from_pretrained(self.model_path, local_files_only=True)
            self.model.eval()

            # Calibrate logit scale for better probability distribution
            if hasattr(self.model, 'logit_scale'):
                with torch.no_grad():
                    self.model.logit_scale.data.fill_(3.80666)  # ln(45)

            self.loaded = True
            logger.info("✅ MedSigClip Model Loaded Successfully (448x448 SigLIP architecture)")
        except Exception as e:
            self.load_error = f"Exception during load: {str(e)}"
            logger.error(f"Failed to load model: {str(e)}")

    def predict(self, image_bytes: bytes) -> Dict[str, Any]:
        """Run hierarchical inference using SigLIP Zero-Shot.

        Args:
            image_bytes: Raw bytes of a PNG, JPEG or DICOM file.

        Returns:
            A localized result dict with domain, findings, heatmap, quality
            and priority fields mapped to frontend expectations.

        Raises:
            RuntimeError: if the model was not loaded.
            ValueError: if the bytes cannot be decoded as an image.
        """
        if not self.loaded:
            msg = "MedSigClip Model is NOT loaded. Cannot perform inference."
            if self.load_error:
                msg += f" Reason: {self.load_error}"
            raise RuntimeError(msg)

        logger.info("Starting inference pipeline...")
        start_time = time.time()

        try:
            from PIL import Image
            import io
            import torch
            import pydicom

            # Image preprocessing functions
            def process_dicom(file_bytes: bytes) -> Tuple[Image.Image, Dict[str, Any]]:
                """Convert DICOM bytes to PIL Image with tags."""
                ds = pydicom.dcmread(io.BytesIO(file_bytes))
                img = ds.pixel_array.astype(np.float32)

                # Extract Metadata
                metadata = {
                    "patient_id": str(ds.get("PatientID", "N/A")),
                    "patient_name": str(ds.get("PatientName", "N/A")),
                    "birth_date": str(ds.get("PatientBirthDate", "")),
                    "study_date": str(ds.get("StudyDate", "")),
                    "modality": str(ds.get("Modality", "UNKNOWN"))
                }

                # MONOCHROME1 stores inverted intensities; flip to the usual convention.
                if hasattr(ds, 'PhotometricInterpretation') and ds.PhotometricInterpretation == "MONOCHROME1":
                    img = img.max() - img

                # Lung Window: WL=-600, WW=1500
                # NOTE(review): this lung window is applied to every DICOM
                # regardless of Modality — confirm it suits non-chest inputs.
                wl, ww = -600, 1500
                min_val, max_val = wl - ww/2, wl + ww/2
                img = np.clip(img, min_val, max_val)
                img = (img - min_val) / (max_val - min_val)
                img = (img * 255).astype(np.uint8)

                return Image.fromarray(img).convert("RGB"), metadata

            def process_standard_image(image_bytes: bytes) -> Image.Image:
                """Process standard images (PNG/JPG) - SIMPLIFIED like Colab.
                Just load the image as RGB without aggressive preprocessing."""
                nparr = np.frombuffer(image_bytes, np.uint8)
                img_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                if img_cv is None:
                    raise ValueError("Could not decode image")

                # Convert BGR to RGB (OpenCV uses BGR)
                img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)

                return Image.fromarray(img_rgb)

            # Detect image format from the magic bytes
            header = image_bytes[:32]
            is_png = header.startswith(b'\x89PNG\r\n\x1a\n')
            is_jpeg = header.startswith(b'\xff\xd8\xff')

            image = None
            dicom_metadata = None

            if is_png or is_jpeg:
                try:
                    image = process_standard_image(image_bytes)
                    logger.info(f"Processed as {'PNG' if is_png else 'JPEG'}")
                except Exception as e:
                    raise ValueError(f"Corrupt Image File: {str(e)}")

            # No recognized magic bytes: try DICOM first, then one last
            # standard decode attempt before giving up.
            if image is None:
                try:
                    image, dicom_metadata = process_dicom(image_bytes)
                    logger.info("Processed as DICOM")
                except Exception:
                    try:
                        image = process_standard_image(image_bytes)
                    except Exception as e:
                        raise ValueError(f"Unknown image format: {str(e)}")

            # =========================================================
            # ADAPTIVE PREPROCESSING - DISABLED to match Colab behavior
            # The model was trained on raw images, not preprocessed ones
            # =========================================================
            preprocessing_log = {"message": "Preprocessing disabled for accuracy", "transformation_count": 0}
            # NOTE: Uncomment below to re-enable if needed
            # try:
            #     import io as io_module
            #     buffer = io_module.BytesIO()
            #     image.save(buffer, format='PNG')
            #     image_bytes_for_preprocessing = buffer.getvalue()
            #     image, preprocessing_log = adaptive_preprocessing(image_bytes_for_preprocessing)
            #     logger.info(f"🔧 Adaptive preprocessing applied: {preprocessing_log.get('transformation_count', 0)} transformations")
            # except Exception as e_preproc:
            #     logger.warning(f"Adaptive preprocessing skipped: {e_preproc}")

            # STEP 1: DOMAIN IDENTIFICATION (zero-shot over the domain prompts)
            domain_keys = list(MEDICAL_DOMAINS.keys())
            domain_prompts = [d['domain_prompt'] for d in MEDICAL_DOMAINS.values()]

            inputs_domain = self.processor(
                text=domain_prompts,
                images=image,
                padding="max_length",
                return_tensors="pt"
            )

            with torch.no_grad():
                outputs_domain = self.model(**inputs_domain)

            probs_domain = torch.softmax(outputs_domain.logits_per_image, dim=1)[0]
            best_domain_idx = torch.argmax(probs_domain).item()
            best_domain_key = domain_keys[best_domain_idx]
            best_domain_prob = float(probs_domain[best_domain_idx] * 100)

            logger.info(f"Identified Domain: {best_domain_key} ({best_domain_prob:.2f}%)")

            # STEP 2: SPECIFIC ANALYSIS
            domain_config = MEDICAL_DOMAINS[best_domain_key]
            specific_results = []

            if 'stage_1_triage' in domain_config:
                # Hierarchical Logic (e.g., Orthopedics)
                logger.info(f"Engaging Level 2 Hierarchical Logic for: {best_domain_key}")

                triage_labels = domain_config['stage_1_triage']['labels']
                inputs_triage = self.processor(text=triage_labels, images=image, padding="max_length", return_tensors="pt")

                with torch.no_grad():
                    out_triage = self.model(**inputs_triage)

                probs_triage = torch.softmax(out_triage.logits_per_image, dim=1)[0]
                # NOTE(review): probs_triage[-1] is the LAST triage label
                # ('A knee x-ray view'); "abnormal" here effectively means
                # "recognized as a knee", which gates stage 2.
                prob_abnormal = float(probs_triage[-1])
                prob_normal = 1.0 - prob_abnormal

                logger.info(f"Triage: Normal={prob_normal*100:.2f}%, Abnormal={prob_abnormal*100:.2f}%")

                if prob_abnormal > prob_normal:
                    logger.info("Running Stage 2 Diagnosis...")
                    diag_labels = domain_config['stage_2_diagnosis']['labels']
                    inputs_diag = self.processor(text=diag_labels, images=image, padding="max_length", return_tensors="pt")

                    with torch.no_grad():
                        out_diag = self.model(**inputs_diag)

                    probs_diag = torch.softmax(out_diag.logits_per_image, dim=1)[0]

                    for i, label in enumerate(diag_labels):
                        specific_results.append({
                            "label": label,
                            "probability": round(float(probs_diag[i] * 100), 2)
                        })
                else:
                    # Not recognized as knee: specific_results stays empty.
                    logger.info("Triage indicates Normal/Healthy. Skipping Stage 2.")
            else:
                # Flat Mode (Thoracic, Dermato, etc.)
                specific_labels_raw = domain_config['specific_labels']

                inputs_specific = self.processor(
                    text=specific_labels_raw,
                    images=image,
                    padding="max_length",
                    return_tensors="pt"
                )

                with torch.no_grad():
                    outputs_specific = self.model(**inputs_specific)

                probs_specific = torch.softmax(outputs_specific.logits_per_image, dim=1)[0]

                for i, label in enumerate(specific_labels_raw):
                    specific_results.append({
                        "label": label,
                        "probability": round(float(probs_specific[i] * 100), 2)
                    })

            specific_results.sort(key=lambda x: x['probability'], reverse=True)

            # STEP 3: HEATMAP GENERATION (Grad-CAM++)
            # Best-effort: any failure is logged and the heatmap stays None.
            heatmap_base64 = None
            original_base64 = None

            try:
                if specific_results:
                    top_label_text = specific_results[0]['label']
                    logger.info(f"Generating Heatmap for: {top_label_text}")

                    target_text = [top_label_text]
                    inputs_gradcam = self.processor(
                        text=target_text, images=image, padding="max_length", return_tensors="pt"
                    )

                    input_ids = inputs_gradcam.input_ids
                    attention_mask = getattr(inputs_gradcam, 'attention_mask', None)

                    model_wrapper_cam = HuggingFaceWeirdCLIPWrapper(
                        self.model, input_ids, attention_mask
                    )

                    try:
                        target_layer = self.model.vision_model.post_layernorm
                        target_layers = [target_layer]
                    except AttributeError as e:
                        logger.error(f"Could not find target layer: {e}")
                        raise e

                    cam = GradCAMPlusPlus(
                        model=model_wrapper_cam,
                        target_layers=target_layers,
                        reshape_transform=reshape_transform
                    )

                    grayscale_cam = cam(input_tensor=inputs_gradcam.pixel_values, targets=None)
                    grayscale_cam = grayscale_cam[0, :]

                    # --- FIX: SMOOTHING FOR ORGANIC LOOK ---
                    # ViT attention is blocky by nature. We apply Gaussian Blur to smooth it out.
                    grayscale_cam = cv2.GaussianBlur(grayscale_cam, (13, 13), 0)
                    # ---------------------------------------

                    # Recover the model-normalized input image for overlaying.
                    img_tensor = inputs_gradcam.pixel_values[0].detach().cpu().numpy()
                    img_tensor = np.transpose(img_tensor, (1, 2, 0))
                    img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min())
                    img_tensor = np.clip(img_tensor, 0, 1).astype(np.float32)

                    visualization = show_cam_on_image(img_tensor, grayscale_cam, use_rgb=True)

                    _, buffer = cv2.imencode('.png', cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR))
                    heatmap_base64 = base64.b64encode(buffer).decode('utf-8')

                    original_uint8 = (img_tensor * 255).astype(np.uint8)
                    _, buffer_orig = cv2.imencode('.png', cv2.cvtColor(original_uint8, cv2.COLOR_RGB2BGR))
                    original_base64 = base64.b64encode(buffer_orig).decode('utf-8')

                    logger.info("✅ Grad-CAM++ Heatmap generated successfully")

            except Exception as e_cam:
                import traceback
                logger.error(f"Grad-CAM Generation Failed: {traceback.format_exc()}")

            # FINAL RESULT (Base)
            result_json = {
                "domain": {
                    "label": best_domain_key,
                    "description": MEDICAL_DOMAINS[best_domain_key]['domain_prompt'],
                    "probability": round(best_domain_prob, 2)
                },
                "specific": specific_results,
                "heatmap": heatmap_base64,
                "original_image": original_base64,
                "preprocessing": preprocessing_log  # Algorithm 7 log
            }

            # =========================================================
            # APPLY 7 INTELLIGENCE ALGORITHMS
            # =========================================================
            logger.info("🧠 Applying Intelligence Algorithms...")


            # Convert PIL image to numpy for quality assessment
            image_array = np.array(image)

            # Get image embedding for similar case detection (best-effort)
            try:
                with torch.no_grad():
                    img_inputs = self.processor(images=image, return_tensors="pt")
                    image_embedding = self.model.get_image_features(**img_inputs)
                    image_embedding = image_embedding.cpu().numpy().flatten()
            except Exception as e_emb:
                logger.warning(f"Could not extract embedding: {e_emb}")
                image_embedding = None

            # Enhance result with all algorithms
            enhanced_result = enhance_analysis_result(
                base_result=result_json,
                image_array=image_array,
                embedding=image_embedding,
                case_id=str(uuid.uuid4()),
                patient_info=None  # Can be passed from request later
            )

            # --- MAP TO FRONTEND EXPECTATIONS ---
            # frontend expects: diagnosis, confidence, predictions, quality_metrics, etc.

            # 1. Diagnosis
            top_finding = enhanced_result['specific'][0] if enhanced_result['specific'] else {"label": "Inconnu", "probability": 0}
            enhanced_result['diagnosis'] = top_finding['label']

            # 2. Confidence & Calibrated
            # Preserve the calibrated score (if any) under a separate key
            # before 'confidence' is overwritten with the raw top probability.
            enhanced_result['calibrated_confidence'] = enhanced_result.get('confidence', top_finding['probability'])
            enhanced_result['confidence'] = top_finding['probability']

            # 3. Processing Time (Real Measurement)
            enhanced_result['processing_time'] = round(time.time() - start_time, 3)

            # 4. Predictions (Alias for specific)
            enhanced_result['predictions'] = [
                {"name": item['label'], "probability": item['probability']}
                for item in enhanced_result['specific']
            ]

            # 5. Quality Metrics (Flatten structure)
            if 'image_quality' in enhanced_result:
                enhanced_result['quality_score'] = enhanced_result['image_quality']['quality_score']
                enhanced_result['quality_metrics'] = enhanced_result['image_quality']['metrics']

            # 6. Priority
            # If priority is a dict (from new algo), extract just the level/score for simple display, or keep object
            # Frontend expects string 'priority' sometimes, or maybe object. Let's provide string for badge.
            if isinstance(enhanced_result.get('priority'), str):
                pass
            elif isinstance(enhanced_result.get('priority'), dict):
                # Flatten for frontend simple badge
                enhanced_result['priority'] = enhanced_result['priority'].get('level', 'Normale')

            # 7. DICOM Metadata (if available)
            if dicom_metadata:
                enhanced_result['patient_metadata'] = dicom_metadata

            logger.info("✅ Intelligence Algorithms applied successfully")

            return localize_result(enhanced_result)

        except Exception as e:
            logger.error(f"Inference Error: {str(e)}")
            raise e
1176
+
1177
# =========================================================================
# GLOBAL MODEL INSTANCE
# =========================================================================
# Singleton model wrapper. Stays None until the FastAPI lifespan handler
# (below) loads the weights; endpoints must treat None as "not ready".
model_wrapper: Optional[MedSigClipWrapper] = None
1181
+
1182
# =========================================================================
# FASTAPI LIFECYCLE
# =========================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialize DB tables, then load the ML model.

    Code before `yield` runs once at startup, code after at shutdown.
    Order matters: the database schema (users + analysis registry) must
    exist before the first request can be served.
    """
    global model_wrapper, MODEL_DIR  # CRITICAL: Use global variables
    database.init_db()
    database.init_analysis_registry()

    # Get model path (downloads from HuggingFace Hub if needed)
    MODEL_DIR = get_model_path()

    # Load model weights eagerly so the first request is not slow.
    model_wrapper = MedSigClipWrapper(MODEL_DIR)
    model_wrapper.load()
    logger.info("ElephMind Backend Started")
    yield
    logger.info("ElephMind Backend Shutting Down")
1199
+
1200
app = FastAPI(
    lifespan=lifespan,
    title="ElephMind Medical AI API",
    version="2.0.0",
    description="Medical image analysis powered by SigLIP"
)

# CORS Middleware with configurable origins
# NOTE(review): wildcard origins combined with allow_credentials=True is
# refused by browsers for credentialed requests and is overly permissive
# for a medical API — consider an explicit origin whitelist in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins to fix "Failed to fetch" for user
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
1215
+
1216
@app.middleware("http")
async def limit_concurrency(request: Request, call_next):
    """Limit concurrent requests to MAX_CONCURRENT_USERS."""
    # Health checks and CORS preflight requests bypass the semaphore so
    # monitoring and browsers are never queued behind inference work.
    exempt = request.url.path == "/health" or request.method == "OPTIONS"
    if exempt:
        return await call_next(request)

    # Purely informational: a locked semaphore means this request will wait.
    if concurrency_semaphore.locked():
        logger.warning(f"Concurrency limit ({MAX_CONCURRENT_USERS}) reached. Request queued.")

    async with concurrency_semaphore:
        return await call_next(request)
1227
+
1228
# =========================================================================
# BACKGROUND WORKER
# =========================================================================
async def process_analysis(job_id: str, image_bytes: bytes):
    """Background task to run inference and log to registry."""
    job = jobs.get(job_id)
    if job is None:
        return

    logger.info(f"Processing Job {job_id}")
    job.status = JobStatus.PROCESSING
    started = time.time()

    try:
        if not model_wrapper:
            raise RuntimeError("Model wrapper not initialized.")

        # Run the blocking model inference off the event loop.
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(None, model_wrapper.predict, image_bytes)

        job.result = result
        job.status = JobStatus.COMPLETED

        # Wall-clock inference time, persisted with the registry entry.
        computation_time_ms = int((time.time() - started) * 1000)

        # Log to registry (REAL DATA)
        if job.username and result:
            findings = result.get('specific') or []
            top = findings[0] if findings else {}
            database.log_analysis(
                username=job.username,
                domain=result.get('domain', {}).get('label', 'Unknown'),
                top_diagnosis=top.get('label', 'Unknown'),
                confidence=top.get('probability', 0),
                priority=result.get('priority', 'Normale'),
                computation_time_ms=computation_time_ms,
                file_type=job.file_type or 'Unknown'
            )
            logger.info(f"✅ Job {job_id} logged to registry")

        logger.info(f"✅ Job {job_id} completed in {computation_time_ms}ms")

    except Exception as e:
        logger.error(f"❌ Job {job_id} failed: {str(e)}")
        job.error = str(e)
        job.status = JobStatus.FAILED
1278
+
1279
# =========================================================================
# API ENDPOINTS
# =========================================================================

# --- Authentication ---
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate user and return JWT token."""
    user = database.get_user_by_username(form_data.username)
    # Short-circuits on a missing user, so the hash is never read from None.
    credentials_ok = bool(user) and verify_password(form_data.password, user['hashed_password'])
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(
        data={"sub": user['username']},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
1300
+
1301
@app.post("/register", status_code=status.HTTP_201_CREATED)
async def register_user(user: UserRegister):
    """Register a new user."""
    # The security answer is normalized (trim + lowercase) and hashed, so a
    # database leak does not expose it in clear text.
    created = database.create_user({
        "username": user.username,
        "hashed_password": get_password_hash(user.password),
        "email": user.email,
        "security_question": user.security_question,
        "security_answer": get_password_hash(user.security_answer.strip().lower()),
    })
    if not created:
        raise HTTPException(status_code=400, detail="Username already exists")
    return {"message": "User created successfully"}
1319
+
1320
@app.get("/recover/{username}")
async def get_security_question(username: str):
    """Get security question for password recovery."""
    # NOTE: responds 404 for unknown users, which reveals account existence.
    user = database.get_user_by_username(username)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    return {"question": user['security_question']}
1327
+
1328
@app.post("/recover/reset")
async def reset_password(data: UserResetPassword):
    """Reset password using security question."""
    user = database.get_user_by_username(data.username)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    # Compare against the stored *hash* of the normalized answer
    # (same trim + lowercase normalization as at registration time).
    answer = data.security_answer.strip().lower()
    if not verify_password(answer, user['security_answer']):
        raise HTTPException(status_code=400, detail="Incorrect security answer")

    database.update_password(data.username, get_password_hash(data.new_password))
    return {"message": "Password reset successfully"}
1342
+
1343
# --- Dashboard Analytics (REAL DATA ONLY) ---
# NOTE(review): the same method+path is registered again further down in this
# file (get_dashboard_stats_endpoint, limit=5); only one registration can
# actually serve requests — confirm which and remove the duplicate.
@app.get("/api/dashboard/stats")
async def get_dashboard_statistics(current_user: User = Depends(get_current_user)):
    """
    Get real dashboard statistics for the authenticated user.
    Returns zeros if no analyses have been performed. NO FAKE DATA.
    """
    # Aggregate counters plus the 10 most recent analyses for this user.
    stats = database.get_dashboard_stats(current_user.username)
    recent = database.get_recent_analyses(current_user.username, limit=10)

    return {
        **stats,
        "recent_analyses": recent
    }
1357
+
1358
@app.post("/feedback")
async def submit_feedback(feedback: FeedbackModel):
    """Submit user feedback."""
    # Persist rating and comment; nothing to return beyond an acknowledgement.
    database.add_feedback(feedback.username, feedback.rating, feedback.comment)
    return {"message": "Feedback received"}
1363
+
1364
# --- Medical Analysis ---
@app.post("/analyze", response_model=Dict[str, str])
async def analyze_image(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user)
):
    """
    Analyze a medical image.

    - **Requires authentication**
    - Accepts DICOM (.dcm) and standard images (PNG, JPEG)
    - Returns a job_id for polling results
    """
    # BUGFIX: UploadFile.content_type is None when the client omits the
    # Content-Type header; calling .startswith on None raised AttributeError
    # (an unhandled 500) instead of the intended 400 rejection.
    content_type = file.content_type or ""
    allowed_types = ['image/', 'application/dicom', 'application/octet-stream']
    if not any(content_type.startswith(t) for t in allowed_types):
        logger.warning(f"Rejected file type: {file.content_type}")
        raise HTTPException(status_code=400, detail=f"Invalid file type: {file.content_type}")

    job_id = str(uuid.uuid4())
    logger.info(f"Received Analysis Request. Job ID: {job_id}")

    # The owner is stored encrypted on the job; /result uses it for the
    # ownership check before returning results.
    enc_user = encryption.encrypt_data(current_user.username)
    image_bytes = await file.read()

    # Best-effort persistence: a storage failure must not block the analysis.
    try:
        storage_path = storage_provider.save_file(image_bytes, file.filename)
    except Exception as e:
        logger.error(f"Storage Failed: {e}")
        storage_path = "failed_storage"

    # Determine file type for registry
    file_ext = file.filename.split('.')[-1].upper() if file.filename else 'UNKNOWN'
    if file_ext == 'DCM':
        file_type = 'DICOM'
    elif file_ext in ['PNG', 'JPG', 'JPEG']:
        file_type = file_ext
    else:
        file_type = 'OTHER'

    jobs[job_id] = Job(
        id=job_id,
        status=JobStatus.PENDING,
        created_at=time.time(),
        encrypted_user=enc_user,
        storage_path=storage_path,
        username=current_user.username,  # For registry logging
        file_type=file_type  # For registry logging
    )

    # Inference runs in the background; the client polls /result/{task_id}.
    background_tasks.add_task(process_analysis, job_id, image_bytes)

    return {"task_id": job_id, "status": "pending"}
1417
+
1418
@app.get("/result/{task_id}")
async def get_result(task_id: str, current_user: User = Depends(get_current_user)):
    """
    Get analysis result by task ID.

    - **Requires authentication**
    - Returns job status and results when complete
    """
    job = jobs.get(task_id)
    if job is None:
        logger.warning(f"Job not found: {task_id}")
        raise HTTPException(status_code=404, detail="Job not found")

    # Ownership check: the submitting username was stored encrypted on the
    # job; decrypt and compare before exposing any result data.
    if job.encrypted_user:
        owner = encryption.decrypt_data(job.encrypted_user)
        if owner != current_user.username:
            logger.warning(f"Unauthorized access attempt to job {task_id} by {current_user.username}")
            raise HTTPException(status_code=403, detail="Access denied")

    logger.info(f"Polling Job {task_id}: Status={job.status}")
    return job
1440
+
1441
@app.get("/health")
def health_check():
    """Health check endpoint."""
    # Reports whether the lifespan handler has finished loading the model.
    return {
        "status": "running",
        "model_loaded": bool(model_wrapper) and model_wrapper.loaded,
        "version": "2.0.0"
    }
1450
+
1451
# --- DASHBOARD ENDPOINTS ---

# NOTE(review): "/api/dashboard/stats" is already registered earlier in this
# file (get_dashboard_statistics, limit=10). Registering the same method+path
# twice leaves one handler unreachable — confirm which registration the
# router matches and delete the duplicate.
@app.get("/api/dashboard/stats")
async def get_dashboard_stats_endpoint(current_user: User = Depends(get_current_user)):
    """Get real dashboard statistics for the authenticated user."""
    try:
        # Aggregate counters plus the 5 most recent analyses for this user.
        stats = database.get_dashboard_stats(current_user.username)
        recent = database.get_recent_analyses(current_user.username, limit=5)
        # Combine
        return {
            **stats,
            "recent_analyses": recent
        }
    except Exception as e:
        logger.error(f"Error fetching dashboard stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
1467
+
1468
# =========================================================================
# MAIN ENTRY POINT
# =========================================================================
if __name__ == "__main__":
    # Initialize DB tables including registry
    database.init_db()
    database.init_analysis_registry()

    # Bind address/port are overridable via environment variables.
    uvicorn.run(
        app,
        host=os.getenv("SERVER_HOST", "0.0.0.0"),
        port=int(os.getenv("SERVER_PORT", "8022")),
    )
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ requests
5
+ transformers
6
+ torch
7
+ Pillow
8
+ sentencepiece
9
+ pydicom
10
+ numpy
11
+ grad-cam
12
+ python-jose[cryptography]
13
+ passlib
14
+ argon2-cffi
15
+ bcrypt==4.0.1
16
+ cryptography
17
+ python-dotenv
18
+ opencv-python
19
+ python-swiftclient
20
+ protobuf
21
+ huggingface_hub
scripts/README.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ElephMind Utility Scripts
2
+
3
+ This directory contains maintenance and debug scripts for the ElephMind backend.
4
+
5
+ ## How to Run
6
+
7
+ Because these scripts import modules from the parent `server/` directory, you must run them with the parent directory in your `PYTHONPATH`.
8
+
9
+ **Windows (PowerShell):**
10
+ ```powershell
11
+ $env:PYTHONPATH=".."; python init_admin.py
12
+ ```
13
+
14
+ **Linux/Mac:**
15
+ ```bash
16
+ PYTHONPATH=.. python init_admin.py
17
+ ```
18
+
19
+ ## Available Scripts
20
+
21
+ - **`init_admin.py`**: Creates the initial 'admin' user with secure password hashing.
22
+ - **`verify_admin.py`**: Checks if the admin user exists in the database.
23
+ - **`test_auth.py`**: Unit tests for the authentication logic.
24
+ - **`debug_inference.py`**: Tests the ML model with a dummy image.
25
+ - **`inspect_model.py`**: Prints details about the loaded PyTorch model.
scripts/debug_inference.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModel
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
5
+
6
# Configuration
MODEL_DIR = r"D:\oeil d'elephant"

def test_inference():
    """Smoke-test the local SigLIP checkpoint on a synthetic chest X-ray."""
    print(f"Loading model from {MODEL_DIR}...")
    try:
        model = AutoModel.from_pretrained(MODEL_DIR, local_files_only=True)
        processor = AutoProcessor.from_pretrained(MODEL_DIR, local_files_only=True)
        model.eval()

        # Force the logit temperature so probabilities are spread out.
        if hasattr(model, 'logit_scale'):
            with torch.no_grad():
                model.logit_scale.data.fill_(4.60517)  # exp(4.6) = 100

        print("Model loaded.")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    # Synthetic Chest X-ray: two light ellipses ("lungs") on black.
    image = Image.new('RGB', (448, 448), color=(0, 0, 0))
    canvas = ImageDraw.Draw(image)
    canvas.ellipse([100, 100, 200, 350], fill=(200, 200, 200))
    canvas.ellipse([248, 100, 348, 350], fill=(200, 200, 200))  # Lungs

    # Hypothesis 1: bare single-word prompts.
    prompts = ['Os', 'Poumons', 'Peau', 'Oeil', 'Sein', 'Tissu']

    # Hypothesis 2: slightly descriptive prompts.
    prompts_v2 = [
        'Radiographie Os',
        'Radiographie Poumons',
        'Photo Peau',
        'Fond d\'oeil',
        'Mammographie Sein',
        'Microscope Tissu'
    ]

    print("\nTesting Simple Prompts on Synthetic Chest X-ray:")

    for p_set in (prompts, prompts_v2):
        with torch.no_grad():
            inputs = processor(text=p_set, images=image, padding="max_length", return_tensors="pt")
            logits = model(**inputs).logits_per_image
            probs = torch.sigmoid(logits)[0]
            # Softmax shown alongside sigmoid for comparison.
            probs_softmax = torch.softmax(logits, dim=1)[0]

            for i, prompt in enumerate(p_set):
                l = logits[0][i].item()
                p_sig = probs[i].item()
                p_soft = probs_softmax[i].item()
                print(f"Prompt: '{prompt:<20}' | Logit: {l:.4f} | Sigmoid: {p_sig*100:.6f}% | Softmax: {p_soft*100:.2f}%")
        print("-" * 60)

if __name__ == "__main__":
    test_inference()
scripts/debug_pathology.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModel
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
5
+
6
# Configuration
MODEL_DIR = r"D:\oeil d'elephant"

def test_inference():
    """Probe pathology-level prompts against a synthetic pneumonia X-ray."""
    print(f"Loading model from {MODEL_DIR}...")
    try:
        model = AutoModel.from_pretrained(MODEL_DIR, local_files_only=True)
        processor = AutoProcessor.from_pretrained(MODEL_DIR, local_files_only=True)
        model.eval()

        # Apply fix: pin the logit temperature (exp(4.60517) = 100).
        if hasattr(model, 'logit_scale'):
            with torch.no_grad():
                model.logit_scale.data.fill_(4.60517)

        print("Model loaded.")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    # Synthetic Pneumonia X-ray:
    # two lungs, one with a big white consolidation.
    image = Image.new('RGB', (448, 448), color=(0, 0, 0))
    draw = ImageDraw.Draw(image)
    draw.ellipse([100, 100, 200, 350], fill=(100, 100, 100))  # Left lung (clearer)
    draw.ellipse([248, 100, 348, 350], fill=(200, 200, 200))  # Right lung (consolidated/white)

    # Check "Thoracic" specific labels
    labels = [
        'Cardiomédiastin élargi', 'Cardiomégalie', 'Opacité pulmonaire',
        'Lésion pulmonaire', 'Consolidation', 'Œdème', 'Pneumonie',
        'Atelectasis', 'Pneumothorax', 'Effusion pleurale', 'Pleural Autre'
    ]

    # Try simplified versions too
    simple_labels = [
        'Coeur', 'Gros coeur', 'Opacité',
        'Lésion', 'Blanc', 'Eau', 'Infection',
        'Ecrasé', 'Air', 'Liquide', 'Autre'
    ]

    print("\nTesting Pathology Prompts:")

    with torch.no_grad():
        inputs = processor(text=labels, images=image, padding="max_length", return_tensors="pt")
        outputs = model(**inputs)
        logits = outputs.logits_per_image
        probs = torch.sigmoid(logits)[0]

        print("\nOriginal Labels:")
        for i, label in enumerate(labels):
            print(f"'{label}': Logit {logits[0][i]:.4f} | Prob {probs[i]:.6f}")

        # Test Simple
        inputs_simple = processor(text=simple_labels, images=image, padding="max_length", return_tensors="pt")
        outputs_simple = model(**inputs_simple)
        logits_simple = outputs_simple.logits_per_image
        probs_simple = torch.sigmoid(logits_simple)[0]

        print("\nSimple Labels:")
        for i, label in enumerate(simple_labels):
            # BUGFIX: probs_simple is already 1-D (indexed with [0] above);
            # the previous probs_simple[0][i] indexed a 0-dim tensor and
            # raised IndexError on the first iteration.
            print(f"'{label}': Logit {logits_simple[0][i]:.4f} | Prob {probs_simple[i]:.6f}")

if __name__ == "__main__":
    test_inference()
scripts/init_admin.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import os

# BUGFIX: database.py and main.py live one level above this scripts/ folder
# (see scripts/README.md); appending the script's own directory — the
# previous behavior — could not resolve those imports without PYTHONPATH.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import database
from main import get_password_hash

def create_admin():
    """Create the default 'admin' account if it does not already exist."""
    database.init_db()
    if database.get_user_by_username("admin"):
        print("Admin already exists.")
        return

    admin_data = {
        "username": "admin",
        "hashed_password": get_password_hash("password123"),
        "email": "admin@elephmind.com",
        "security_question": "Quel est votre animal totem ?",
        # The security answer is stored hashed, like a password.
        "security_answer": get_password_hash("elephant")
    }

    if database.create_user(admin_data):
        print("Admin user created successfully. (Login: admin / password123)")
    else:
        print("Failed to create admin user.")

if __name__ == "__main__":
    create_admin()
scripts/inspect_model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
MODEL_DIR = r"D:\oeil d'elephant"

def inspect():
    """Print a summary of the model's JSON config files."""
    for name in ("config.json", "preprocessor_config.json", "tokenizer_config.json"):
        path = os.path.join(MODEL_DIR, name)
        print(f"\n--- {name} ---")
        if not os.path.exists(path):
            print("File not found.")
            continue
        try:
            with open(path, 'r', encoding='utf-8') as fh:
                content = json.load(fh)
            # config.json is large: only show the architecture-level keys.
            if name == "config.json":
                keep = ['architectures', 'model_type', 'logit_scale_init_value', 'vision_config', 'text_config']
                print(json.dumps({k: v for k, v in content.items() if k in keep}, indent=2))
            else:
                print(json.dumps(content, indent=2))
        except Exception as e:
            print(f"Error reading {name}: {e}")

if __name__ == "__main__":
    inspect()
scripts/test_auth.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import sys
3
+
4
BASE_URL = "http://127.0.0.1:8022"

def test_health():
    """Ping /health; return True only when the server answers 200."""
    print(f"Testing Health Check at {BASE_URL}/health...")
    try:
        resp = requests.get(f"{BASE_URL}/health")
        if resp.status_code == 200:
            print("✅ Health Check Passed")
            return True
    except Exception as e:
        print(f"❌ Health Check Failed: {e}")
    # Falls through on non-200 or exception.
    return False
16
+
17
def test_auth():
    """Walk the auth flow against a live server: reject-without-token,
    login, then accept-with-token. Returns True on success, False otherwise."""
    print("Testing Authentication...")

    # 1. Try to access protected route without token
    try:
        r = requests.post(f"{BASE_URL}/analyze")
        if r.status_code == 401:
            print("✅ Protected Endpoint correctly rejected unauthorized request (401)")
        else:
            print(f"❌ Protected Endpoint Failed: Expected 401, got {r.status_code}")
            return False

        # 2. Login to get token
        # NOTE(review): assumes an 'admin' account with password 'secret'
        # exists — init_admin.py creates 'password123'; confirm credentials.
        payload = {"username": "admin", "password": "secret"}
        r = requests.post(f"{BASE_URL}/token", data=payload)
        if r.status_code == 200:
            token = r.json().get("access_token")
            if token:
                print("✅ Login Successful. Token received.")
            else:
                print("❌ Login Failed: No token in response")
                return False
        else:
            print(f"❌ Login Failed: {r.status_code} - {r.text}")
            return False

        # 3. Access protected route WITH token (Should fail on 422 Validation 'Field required' for file, NOT 401)
        headers = {"Authorization": f"Bearer {token}"}
        # We don't send file, expecting 422 Unprocessable Entity (Missing File), which means Auth passed!
        r = requests.post(f"{BASE_URL}/analyze", headers=headers)
        if r.status_code == 422:
            print("✅ Protected Endpoint correctly accepted token (Got 422 for missing file, not 401)")
            return True
        elif r.status_code == 401:
            print("❌ Protected Endpoint rejected valid token (401)")
            return False
        else:
            print(f"⚠️ Unexpected status with token: {r.status_code}")
            return True  # Acceptable for now

    except Exception as e:
        print(f"❌ Test Exception: {e}")
        return False
60
+
61
if __name__ == "__main__":
    # Run the auth flow only when the server answers the health check.
    if test_health():
        test_auth()
scripts/verify_admin.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import sys
4
+ import os
5
+
6
+ # Add server directory to path to import database
7
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
8
+
9
+ import database
10
+
11
+ def check_admin():
12
+ print(f"Checking database: {database.DB_NAME}")
13
+
14
+ # Initialize DB if tables missing (which seems to be the case in this context)
15
+ database.init_db()
16
+
17
+ try:
18
+ user = database.get_user_by_username("admin")
19
+ if user:
20
+ print("USER 'admin' FOUND.")
21
+ print(f" ID: {user['id']}")
22
+ print(f" Email: {user['email']}")
23
+ else:
24
+ print("USER 'admin' NOT FOUND.")
25
+ except Exception as e:
26
+ print(f"Error querying database: {e}")
27
+
28
+ if __name__ == "__main__":
29
+ check_admin()
30
+
secret.key ADDED
@@ -0,0 +1 @@
 
 
1
+ 6cfBgfzHb12RD2eW_9QxGrpoDdScGYoqpV3MYvz96LE=
storage.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import abc
3
+ from datetime import datetime
4
+
5
class StorageProvider(abc.ABC):
    """Abstract interface for persisting uploaded files (local disk or Swift)."""

    @abc.abstractmethod
    def save_file(self, file_bytes: bytes, filename: str) -> str:
        """Persist `file_bytes` and return a provider-specific storage path."""
        pass

    @abc.abstractmethod
    def get_file(self, filename: str) -> bytes:
        """Return the stored bytes for `filename` (implementations here
        return None when it cannot be found)."""
        pass
13
+
14
class LocalStorage(StorageProvider):
    """Filesystem-backed storage under `base_dir` (default: ./data_storage)."""

    def __init__(self, base_dir="data_storage"):
        self.base_dir = base_dir
        os.makedirs(base_dir, exist_ok=True)

    def save_file(self, file_bytes: bytes, filename: str) -> str:
        """Write `file_bytes` to disk and return the resulting path.

        The name is prefixed with a timestamp to avoid collisions.
        """
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        # BUGFIX: actually use the caller-supplied filename (the parameter
        # was previously ignored and a literal placeholder was written).
        # basename() strips directory components a malicious client could
        # smuggle in (e.g. "../../etc/passwd").
        safe_name = f"{ts}_{os.path.basename(filename)}"
        path = os.path.join(self.base_dir, safe_name)
        with open(path, "wb") as f:
            f.write(file_bytes)
        return path

    def get_file(self, filename: str) -> bytes:
        """Read back a previously saved file; None when it does not exist."""
        path = os.path.join(self.base_dir, filename)
        if not os.path.exists(path):
            return None
        with open(path, "rb") as f:
            return f.read()
34
+
35
class SwiftStorage(StorageProvider):
    """
    OpenStack Swift Storage Provider.
    Requires python-swiftclient installed.
    """

    def __init__(self, auth_url, username, password, project_name, container_name="elephmind_images"):
        # Import here to avoid error on Windows if not installed
        try:
            from swiftclient import Connection
        except ImportError:
            raise ImportError("python-swiftclient not installed!")

        self.container_name = container_name
        self.conn = Connection(
            authurl=auth_url,
            user=username,
            key=password,
            tenant_name=project_name,
            auth_version='3',
            os_options={'user_domain_name': 'Default', 'project_domain_name': 'Default'}
        )
        # Ensure container exists (best-effort: log and continue on failure)
        try:
            self.conn.put_container(self.container_name)
        except Exception as e:
            print(f"Swift Connection Error: {e}")

    def save_file(self, file_bytes: bytes, filename: str) -> str:
        """Upload the bytes to the container; returns a swift:// pseudo-URL."""
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        # BUGFIX: store under the caller-supplied filename (a literal
        # placeholder was previously written); basename() strips any
        # client-supplied directory components.
        safe_name = f"{ts}_{os.path.basename(filename)}"
        self.conn.put_object(
            self.container_name,
            safe_name,
            contents=file_bytes,
            content_type='application/octet-stream'
        )
        return f"swift://{self.container_name}/{safe_name}"

    def get_file(self, filename: str) -> bytes:
        """Download an object by name; None when missing or on any Swift error."""
        # filename is expected to be the stored object name (safe_name).
        try:
            _, obj = self.conn.get_object(self.container_name, filename)
            return obj
        except Exception:
            return None
81
+
82
# Factory
def get_storage_provider(config_mode="LOCAL"):
    """Build the storage backend selected by `config_mode`.

    "OPENSTACK" yields a Swift client configured from OS_* environment
    variables; anything else falls back to local-disk storage.
    """
    if config_mode != "OPENSTACK":
        return LocalStorage()
    return SwiftStorage(
        auth_url=os.getenv("OS_AUTH_URL"),
        username=os.getenv("OS_USERNAME"),
        password=os.getenv("OS_PASSWORD"),
        project_name=os.getenv("OS_PROJECT_NAME")
    )
upload_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# upload_model.py - Upload model to Hugging Face Hub
from huggingface_hub import upload_folder
import os

model_path = os.path.join("models", "oeil d'elephant")
print(f"Uploading from: {model_path}")
print(f"Path exists: {os.path.exists(model_path)}")

if not os.path.exists(model_path):
    # Help the operator see what is actually on disk.
    print(f"ERROR: Path not found: {model_path}")
    print("Available in models/:")
    if os.path.exists("models"):
        print(os.listdir("models"))
else:
    print("Starting upload... (this may take a while for 3.5GB)")
    upload_folder(
        folder_path=model_path,
        repo_id="issoufzousko07/medsigclip-model",
        repo_type="model"
    )
    print("Upload complete!")
upload_space.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# upload_space.py - Upload code to HuggingFace Space (excluding large model files)
from huggingface_hub import upload_folder
import os

print("Uploading ElephMind API to HuggingFace Space...")
print("(Model will be downloaded from Hub at runtime)")

# Skip model weights, caches, databases, stored images, and these helpers.
excluded = ["models/*", "*.pyc", "__pycache__", "*.db", "storage/*", "upload_model.py", "upload_space.py"]
upload_folder(
    folder_path=".",
    repo_id="issoufzousko07/elephmind-api",
    repo_type="space",
    ignore_patterns=excluded
)

print("✅ Upload complete!")
print("Your Space should start building at: https://huggingface.co/spaces/issoufzousko07/elephmind-api")
+ print("Your Space should start building at: https://huggingface.co/spaces/issoufzousko07/elephmind-api")