Upload folder using huggingface_hub

#13
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ models/
2
+ data_storage/
3
+ elephmind.db
4
+ .env
5
+ venv/
6
+ __pycache__/
Dockerfile CHANGED
@@ -1,38 +1,38 @@
1
- # Hugging Face Spaces Docker Configuration
2
- FROM python:3.10-slim
3
-
4
- # Create non-root user (required by HuggingFace)
5
- RUN useradd -m -u 1000 user
6
- ENV HOME=/home/user
7
- ENV PATH="/home/user/.local/bin:$PATH"
8
-
9
- WORKDIR /app
10
-
11
- # Install system dependencies as root
12
- RUN apt-get update && apt-get install -y --no-install-recommends \
13
- libgl1 \
14
- libglib2.0-0 \
15
- libsm6 \
16
- libxext6 \
17
- libxrender1 \
18
- && rm -rf /var/lib/apt/lists/*
19
-
20
- # Create directories as root BEFORE switching user
21
- RUN mkdir -p /app/storage/uploads /app/storage/processed && \
22
- chown -R user:user /app
23
-
24
- # Switch to non-root user
25
- USER user
26
-
27
- # Copy requirements and install
28
- COPY --chown=user requirements.txt .
29
- RUN pip install --no-cache-dir --user -r requirements.txt
30
-
31
- # Copy the rest of the application
32
- COPY --chown=user . /app
33
-
34
- # Expose port 7860 (required by Hugging Face Spaces)
35
- EXPOSE 7860
36
-
37
- # Run the application
38
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
+ # Hugging Face Spaces Docker Configuration
2
+ FROM python:3.10-slim
3
+
4
+ # Create non-root user (required by HuggingFace)
5
+ RUN useradd -m -u 1000 user
6
+ ENV HOME=/home/user
7
+ ENV PATH="/home/user/.local/bin:$PATH"
8
+
9
+ WORKDIR /app
10
+
11
+ # Install system dependencies as root
12
+ RUN apt-get update && apt-get install -y --no-install-recommends \
13
+ libgl1 \
14
+ libglib2.0-0 \
15
+ libsm6 \
16
+ libxext6 \
17
+ libxrender1 \
18
+ && rm -rf /var/lib/apt/lists/*
19
+
20
+ # Create directories as root BEFORE switching user
21
+ RUN mkdir -p /app/storage/uploads /app/storage/processed && \
22
+ chown -R user:user /app
23
+
24
+ # Switch to non-root user
25
+ USER user
26
+
27
+ # Copy requirements and install
28
+ COPY --chown=user requirements.txt .
29
+ RUN pip install --no-cache-dir --user -r requirements.txt
30
+
31
+ # Copy the rest of the application
32
+ COPY --chown=user . /app
33
+
34
+ # Expose port 7860 (required by Hugging Face Spaces)
35
+ EXPOSE 7860
36
+
37
+ # Run the application
38
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,16 +1,21 @@
1
- ---
2
- title: Elephmind Api
3
- emoji: 🏃
4
- colorFrom: indigo
5
- colorTo: red
6
- sdk: docker
7
- pinned: false
8
- license: gemma
9
- short_description: ' IA imagerie medical '
10
- ---
11
-
12
- # ElephMind API v5.1
13
-
14
- Backend for medical image analysis.
15
-
16
- **Last Update**: 2026-01-31 17:00 - Fixed KeyError issues and disabled faulty CTR morphology engine.
 
 
 
 
 
 
1
+ ---
2
+ title: ElephMind Medical AI
3
+ emoji: 🏥
4
+ colorFrom: green
5
+ colorTo: green
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: true
9
+ license: apache-2.0
10
+ ---
11
+
12
+ # ElephMind - Diagnostic IA Médical
13
+
14
+ Application d'aide au diagnostic médical basée sur l'intelligence artificielle.
15
+
16
+ ## Fonctionnalités
17
+ - Analyse de radiographies thoraciques
18
+ - Analyse dermatologique
19
+ - Analyse histologique
20
+ - Analyse ophtalmologique
21
+ - Analyse orthopédique
database.py CHANGED
@@ -1,513 +1,464 @@
1
- import sqlite3
2
- import os
3
- import logging
4
- from typing import Optional, List, Dict, Any
5
- from enum import Enum
6
-
7
- class JobStatus(str, Enum):
8
- PENDING = "pending"
9
- PROCESSING = "processing"
10
- COMPLETED = "completed"
11
- FAILED = "failed"
12
-
13
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14
- # HUGGING FACE PERSISTENCE FIX: Use /data if available
15
- if os.path.exists('/data'):
16
- DB_NAME = '/data/elephmind.db'
17
- logging.info("Using PERSISTENT storage at /data/elephmind.db")
18
- else:
19
- DB_NAME = os.path.join(BASE_DIR, "elephmind.db")
20
- logging.info(f"Using LOCAL storage at {DB_NAME}")
21
-
22
- def get_db_connection():
23
- conn = sqlite3.connect(DB_NAME)
24
- conn.row_factory = sqlite3.Row
25
- return conn
26
-
27
- def init_db():
28
- conn = get_db_connection()
29
- c = conn.cursor()
30
-
31
- # Create Users Table
32
- c.execute('''
33
- CREATE TABLE IF NOT EXISTS users (
34
- id INTEGER PRIMARY KEY AUTOINCREMENT,
35
- username TEXT UNIQUE NOT NULL,
36
- hashed_password TEXT NOT NULL,
37
- email TEXT,
38
- security_question TEXT NOT NULL,
39
- security_answer TEXT NOT NULL,
40
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
41
- )
42
- ''')
43
-
44
- # Create Feedback Table
45
- c.execute('''
46
- CREATE TABLE IF NOT EXISTS feedback (
47
- id INTEGER PRIMARY KEY AUTOINCREMENT,
48
- username TEXT,
49
- rating INTEGER,
50
- comment TEXT,
51
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
52
- )
53
- ''')
54
-
55
- # Create Audit Log Table (RGPD Compliance)
56
- c.execute('''
57
- CREATE TABLE IF NOT EXISTS audit_log (
58
- id INTEGER PRIMARY KEY AUTOINCREMENT,
59
- username TEXT,
60
- action TEXT NOT NULL,
61
- resource TEXT,
62
- ip_address TEXT,
63
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
64
- )
65
- ''')
66
-
67
- # --- MIGRATIONS ---
68
- # Ensure security columns exist (backward compatibility)
69
- try:
70
- c.execute("ALTER TABLE users ADD COLUMN security_question TEXT DEFAULT 'Question?'")
71
- except sqlite3.OperationalError:
72
- pass # Column exists
73
-
74
- try:
75
- c.execute("ALTER TABLE users ADD COLUMN security_answer TEXT DEFAULT 'answer'")
76
- except sqlite3.OperationalError:
77
- pass # Column exists
78
- # ------------------
79
-
80
- # Create Patients Table
81
- c.execute('''
82
- CREATE TABLE IF NOT EXISTS patients (
83
- id INTEGER PRIMARY KEY AUTOINCREMENT,
84
- patient_id TEXT UNIQUE NOT NULL, -- e.g. PAT-2026-1234
85
- owner_username TEXT NOT NULL,
86
- first_name TEXT,
87
- last_name TEXT,
88
- birth_date TEXT,
89
- photo TEXT, -- Stores base64 or URL
90
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
91
- FOREIGN KEY(owner_username) REFERENCES users(username)
92
- )
93
- ''')
94
-
95
- # Create Jobs Table (PERSISTENCE)
96
- c.execute('''
97
- CREATE TABLE IF NOT EXISTS jobs (
98
- id TEXT PRIMARY KEY,
99
- status TEXT NOT NULL,
100
- result TEXT, -- JSON serialized
101
- error TEXT,
102
- created_at REAL,
103
- storage_path TEXT,
104
- username TEXT,
105
- file_type TEXT,
106
- FOREIGN KEY(username) REFERENCES users(username)
107
- )
108
- ''')
109
-
110
- conn.commit()
111
- conn.close()
112
- logging.info(f"Database {DB_NAME} initialized successfully.")
113
-
114
- # --- User Operations ---
115
-
116
- def create_user(user: Dict[str, Any]) -> bool:
117
- try:
118
- conn = get_db_connection()
119
- c = conn.cursor()
120
- c.execute('''
121
- INSERT INTO users (username, hashed_password, email, security_question, security_answer)
122
- VALUES (?, ?, ?, ?, ?)
123
- ''', (
124
- user['username'],
125
- user['hashed_password'],
126
- user.get('email', ''),
127
- user['security_question'],
128
- user['security_answer']
129
- ))
130
- conn.commit()
131
- return True
132
- except sqlite3.IntegrityError:
133
- return False
134
- except Exception as e:
135
- logging.error(f"Error creating user: {e}")
136
- return False
137
- finally:
138
- conn.close()
139
-
140
- def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
141
- conn = get_db_connection()
142
- c = conn.cursor()
143
- c.execute('SELECT * FROM users WHERE username = ?', (username,))
144
- row = c.fetchone()
145
- conn.close()
146
- if row:
147
- return dict(row)
148
- return None
149
-
150
- def update_password(username: str, new_hashed_password: str) -> bool:
151
- try:
152
- conn = get_db_connection()
153
- c = conn.cursor()
154
- c.execute('UPDATE users SET hashed_password = ? WHERE username = ?', (new_hashed_password, username))
155
- conn.commit()
156
- conn.close()
157
- return True
158
- except Exception as e:
159
- logging.error(f"Error updating password: {e}")
160
- return False
161
-
162
- # --- Feedback Operations ---
163
-
164
- def add_feedback(username: str, rating: int, comment: str):
165
- conn = get_db_connection()
166
- c = conn.cursor()
167
- c.execute('INSERT INTO feedback (username, rating, comment) VALUES (?, ?, ?)', (username, rating, comment))
168
- conn.commit()
169
- conn.close()
170
-
171
- # --- Audit Log Operations (RGPD Compliance) ---
172
-
173
- def log_audit(username: str, action: str, resource: str = None, ip_address: str = None):
174
- """Log user actions for RGPD compliance and security auditing."""
175
- try:
176
- conn = get_db_connection()
177
- c = conn.cursor()
178
- c.execute(
179
- 'INSERT INTO audit_log (username, action, resource, ip_address) VALUES (?, ?, ?, ?)',
180
- (username, action, resource, ip_address)
181
- )
182
- conn.commit()
183
- conn.close()
184
- except Exception as e:
185
- logging.error(f"Error logging audit: {e}")
186
-
187
- def get_user_audit_log(username: str, limit: int = 100) -> List[Dict[str, Any]]:
188
- """Get audit log for a specific user."""
189
- conn = get_db_connection()
190
- c = conn.cursor()
191
- c.execute(
192
- 'SELECT * FROM audit_log WHERE username = ? ORDER BY created_at DESC LIMIT ?',
193
- (username, limit)
194
- )
195
- rows = c.fetchall()
196
- conn.close()
197
- return [dict(row) for row in rows]
198
-
199
- # --- Analysis Registry (REAL DATA ONLY) ---
200
-
201
- def init_analysis_registry():
202
- """Create the analysis_registry table if it doesn't exist."""
203
- conn = get_db_connection()
204
- c = conn.cursor()
205
- c.execute('''
206
- CREATE TABLE IF NOT EXISTS analysis_registry (
207
- id INTEGER PRIMARY KEY AUTOINCREMENT,
208
- username TEXT NOT NULL,
209
- domain TEXT NOT NULL,
210
- top_diagnosis TEXT,
211
- confidence REAL,
212
- priority TEXT,
213
- computation_time_ms INTEGER,
214
- file_type TEXT,
215
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
216
- )
217
- ''')
218
- conn.commit()
219
- conn.close()
220
-
221
- def log_analysis(
222
- username: str,
223
- domain: str,
224
- top_diagnosis: str,
225
- confidence: float,
226
- priority: str,
227
- computation_time_ms: int,
228
- file_type: str
229
- ) -> bool:
230
- """Log a real analysis to the registry. NO FAKE DATA."""
231
- try:
232
- conn = get_db_connection()
233
- c = conn.cursor()
234
- c.execute('''
235
- INSERT INTO analysis_registry
236
- (username, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type)
237
- VALUES (?, ?, ?, ?, ?, ?, ?)
238
- ''', (username, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type))
239
- conn.commit()
240
- conn.close()
241
- return True
242
- except Exception as e:
243
- logging.error(f"Error logging analysis: {e}")
244
- return False
245
-
246
- def get_dashboard_stats(username: str) -> Dict[str, Any]:
247
- """Get real dashboard statistics for a user. Returns zeros if no data."""
248
- conn = get_db_connection()
249
- c = conn.cursor()
250
-
251
- # Total count
252
- c.execute('SELECT COUNT(*) FROM analysis_registry WHERE username = ?', (username,))
253
- total = c.fetchone()[0]
254
-
255
- # By domain
256
- c.execute('''
257
- SELECT domain, COUNT(*) as count
258
- FROM analysis_registry
259
- WHERE username = ?
260
- GROUP BY domain
261
- ''', (username,))
262
- by_domain = {row['domain']: row['count'] for row in c.fetchall()}
263
-
264
- # By priority
265
- c.execute('''
266
- SELECT priority, COUNT(*) as count
267
- FROM analysis_registry
268
- WHERE username = ?
269
- GROUP BY priority
270
- ''', (username,))
271
- by_priority = {row['priority']: row['count'] for row in c.fetchall()}
272
-
273
- # Average computation time
274
- c.execute('''
275
- SELECT AVG(computation_time_ms)
276
- FROM analysis_registry
277
- WHERE username = ?
278
- ''', (username,))
279
- avg_time = c.fetchone()[0] or 0
280
-
281
- conn.close()
282
-
283
- return {
284
- "total_analyses": total,
285
- "by_domain": by_domain,
286
- "by_priority": by_priority,
287
- "avg_computation_time_ms": round(avg_time, 0)
288
- }
289
-
290
- def get_recent_analyses(username: str, limit: int = 10) -> List[Dict[str, Any]]:
291
- """Get recent real analyses for a user. Returns empty list if none."""
292
- conn = get_db_connection()
293
- c = conn.cursor()
294
- c.execute('''
295
- SELECT id, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type, created_at
296
- FROM analysis_registry
297
- WHERE username = ?
298
- ORDER BY created_at DESC
299
- LIMIT ?
300
- ''', (username, limit))
301
- rows = c.fetchall()
302
- conn.close()
303
- return [dict(row) for row in rows]
304
-
305
- # --- Patient Operations (New for Migration) ---
306
-
307
- def create_patient(
308
- owner_username: str,
309
- patient_id: str,
310
- first_name: str,
311
- last_name: str,
312
- birth_date: str,
313
- photo: str
314
- ) -> Optional[int]:
315
- """Create a new patient record."""
316
- try:
317
- conn = get_db_connection()
318
- c = conn.cursor()
319
- c.execute('''
320
- INSERT INTO patients (owner_username, patient_id, first_name, last_name, birth_date, photo)
321
- VALUES (?, ?, ?, ?, ?, ?)
322
- ''', (owner_username, patient_id, first_name, last_name, birth_date, photo))
323
- patient_id_db = c.lastrowid
324
- conn.commit()
325
- conn.close()
326
- return patient_id_db
327
- except Exception as e:
328
- logging.error(f"Error creating patient: {e}")
329
- return None
330
-
331
- def get_patients_by_user(username: str) -> List[Dict[str, Any]]:
332
- """Get all patients belonging to a user."""
333
- conn = get_db_connection()
334
- c = conn.cursor()
335
- c.execute('SELECT * FROM patients WHERE owner_username = ? ORDER BY created_at DESC', (username,))
336
- rows = c.fetchall()
337
- conn.close()
338
- return [dict(row) for row in rows]
339
-
340
- def delete_patient(username: str, patient_db_id: int) -> bool:
341
- """Delete a patient record if owned by user."""
342
- try:
343
- conn = get_db_connection()
344
- c = conn.cursor()
345
- c.execute('DELETE FROM patients WHERE id = ? AND owner_username = ?', (patient_db_id, username))
346
- count = c.rowcount
347
- conn.commit()
348
- conn.close()
349
- return count > 0
350
- except Exception as e:
351
- logging.error(f"Error deleting patient: {e}")
352
- return False
353
-
354
- def update_patient(username: str, patient_db_id: int, updates: Dict[str, Any]) -> bool:
355
- """Update patient fields."""
356
- try:
357
- conn = get_db_connection()
358
- c = conn.cursor()
359
-
360
- # Build query dynamically
361
- fields = []
362
- values = []
363
- for k, v in updates.items():
364
- if k in ['first_name', 'last_name', 'birth_date', 'photo']:
365
- fields.append(f"{k} = ?")
366
- values.append(v)
367
-
368
- if not fields:
369
- return False
370
-
371
- values.extend([patient_db_id, username])
372
- query = f"UPDATE patients SET {', '.join(fields)} WHERE id = ? AND owner_username = ?"
373
-
374
-
375
- c.execute(query, values)
376
- count = c.rowcount
377
- conn.commit()
378
- conn.close()
379
- return count > 0
380
- except Exception as e:
381
- logging.error(f"Error updating patient: {e}")
382
- return False
383
-
384
- # --- Job Operations (Persistence) ---
385
-
386
- import json
387
-
388
- def create_job(job_data: Dict[str, Any]):
389
- """Create a new job record."""
390
- try:
391
- conn = get_db_connection()
392
- c = conn.cursor()
393
- c.execute('''
394
- INSERT INTO jobs (id, status, result, error, created_at, storage_path, username, file_type)
395
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
396
- ''', (
397
- job_data['id'],
398
- job_data.get('status', 'pending'),
399
- json.dumps(job_data.get('result')) if job_data.get('result') else None,
400
- job_data.get('error'),
401
- job_data['created_at'],
402
- job_data.get('storage_path'),
403
- job_data.get('username'),
404
- job_data.get('file_type')
405
- ))
406
- conn.commit()
407
- conn.close()
408
- return True
409
- except Exception as e:
410
- logging.error(f"Error creating job: {e}")
411
- return False
412
-
413
- def get_job(job_id: str, username: Optional[str] = None) -> Optional[Dict[str, Any]]:
414
- """Retrieve job by ID, optionally enforcing ownership via SQL."""
415
- conn = get_db_connection()
416
- c = conn.cursor()
417
-
418
- if username:
419
- c.execute('SELECT * FROM jobs WHERE id = ? AND username = ?', (job_id, username))
420
- else:
421
- c.execute('SELECT * FROM jobs WHERE id = ?', (job_id,))
422
-
423
- row = c.fetchone()
424
- conn.close()
425
-
426
- if row:
427
- job = dict(row)
428
- if job['result']:
429
- try:
430
- job['result'] = json.loads(job['result'])
431
- except:
432
- job['result'] = None
433
- return job
434
- return None
435
-
436
- def update_job_status(job_id: str, status: str, result: Optional[Dict] = None, error: Optional[str] = None):
437
- """Update job status and result."""
438
- try:
439
- conn = get_db_connection()
440
- c = conn.cursor()
441
-
442
- updates = ["status = ?"]
443
- params = [status]
444
-
445
- if result is not None:
446
- updates.append("result = ?")
447
- params.append(json.dumps(result))
448
-
449
- if error is not None:
450
- updates.append("error = ?")
451
- params.append(error)
452
-
453
- params.append(job_id)
454
-
455
- query = f"UPDATE jobs SET {', '.join(updates)} WHERE id = ?"
456
- c.execute(query, params)
457
- conn.commit()
458
- conn.close()
459
- return True
460
- except Exception as e:
461
- logging.error(f"Error updating job: {e}")
462
- return False
463
-
464
-
465
-
466
- def get_latest_job(username: str) -> Optional[Dict[str, Any]]:
467
- """Retrieve the most recent job for a user."""
468
- conn = get_db_connection()
469
- c = conn.cursor()
470
- c.execute('''
471
- SELECT * FROM jobs
472
- WHERE username = ?
473
- ORDER BY created_at DESC
474
- LIMIT 1
475
- ''', (username,))
476
- row = c.fetchone()
477
- conn.close()
478
-
479
- if row:
480
- job = dict(row)
481
- if job['result']:
482
- try:
483
- job['result'] = json.loads(job['result'])
484
- except:
485
- job['result'] = None
486
- return job
487
- return None
488
-
489
- def get_active_job_by_image(username: str, image_id: str) -> Optional[Dict[str, Any]]:
490
- """
491
- Retrieve the most recent job for a specific image and user.
492
- Used for Idempotence (Strict Lifecycle).
493
- """
494
- conn = get_db_connection()
495
- c = conn.cursor()
496
- c.execute('''
497
- SELECT * FROM jobs
498
- WHERE username = ? AND storage_path = ?
499
- ORDER BY created_at DESC
500
- LIMIT 1
501
- ''', (username, image_id))
502
- row = c.fetchone()
503
- conn.close()
504
-
505
- if row:
506
- job = dict(row)
507
- if job['result']:
508
- try:
509
- job['result'] = json.loads(job['result'])
510
- except:
511
- job['result'] = None
512
- return job
513
- return None
 
1
+ import sqlite3
2
+ import os
3
+ import logging
4
+ from typing import Optional, List, Dict, Any
5
+ from enum import Enum
6
+
7
+ class JobStatus(str, Enum):
8
+ PENDING = "pending"
9
+ PROCESSING = "processing"
10
+ COMPLETED = "completed"
11
+ FAILED = "failed"
12
+
13
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
14
+ # HUGGING FACE PERSISTENCE FIX: Use /data if available
15
+ if os.path.exists('/data'):
16
+ DB_NAME = '/data/elephmind.db'
17
+ logging.info("Using PERSISTENT storage at /data/elephmind.db")
18
+ else:
19
+ DB_NAME = os.path.join(BASE_DIR, "elephmind.db")
20
+ logging.info(f"Using LOCAL storage at {DB_NAME}")
21
+
22
+ def get_db_connection():
23
+ conn = sqlite3.connect(DB_NAME)
24
+ conn.row_factory = sqlite3.Row
25
+ return conn
26
+
27
+ def init_db():
28
+ conn = get_db_connection()
29
+ c = conn.cursor()
30
+
31
+ # Create Users Table
32
+ c.execute('''
33
+ CREATE TABLE IF NOT EXISTS users (
34
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
35
+ username TEXT UNIQUE NOT NULL,
36
+ hashed_password TEXT NOT NULL,
37
+ email TEXT,
38
+ security_question TEXT NOT NULL,
39
+ security_answer TEXT NOT NULL,
40
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
41
+ )
42
+ ''')
43
+
44
+ # Create Feedback Table
45
+ c.execute('''
46
+ CREATE TABLE IF NOT EXISTS feedback (
47
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
48
+ username TEXT,
49
+ rating INTEGER,
50
+ comment TEXT,
51
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
52
+ )
53
+ ''')
54
+
55
+ # Create Audit Log Table (RGPD Compliance)
56
+ c.execute('''
57
+ CREATE TABLE IF NOT EXISTS audit_log (
58
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
59
+ username TEXT,
60
+ action TEXT NOT NULL,
61
+ resource TEXT,
62
+ ip_address TEXT,
63
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
64
+ )
65
+ ''')
66
+
67
+ # --- MIGRATIONS ---
68
+ # Ensure security columns exist (backward compatibility)
69
+ try:
70
+ c.execute("ALTER TABLE users ADD COLUMN security_question TEXT DEFAULT 'Question?'")
71
+ except sqlite3.OperationalError:
72
+ pass # Column exists
73
+
74
+ try:
75
+ c.execute("ALTER TABLE users ADD COLUMN security_answer TEXT DEFAULT 'answer'")
76
+ except sqlite3.OperationalError:
77
+ pass # Column exists
78
+ # ------------------
79
+
80
+ # Create Patients Table
81
+ c.execute('''
82
+ CREATE TABLE IF NOT EXISTS patients (
83
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
84
+ patient_id TEXT UNIQUE NOT NULL, -- e.g. PAT-2026-1234
85
+ owner_username TEXT NOT NULL,
86
+ first_name TEXT,
87
+ last_name TEXT,
88
+ birth_date TEXT,
89
+ photo TEXT, -- Stores base64 or URL
90
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
91
+ FOREIGN KEY(owner_username) REFERENCES users(username)
92
+ )
93
+ ''')
94
+
95
+ # Create Jobs Table (PERSISTENCE)
96
+ c.execute('''
97
+ CREATE TABLE IF NOT EXISTS jobs (
98
+ id TEXT PRIMARY KEY,
99
+ status TEXT NOT NULL,
100
+ result TEXT, -- JSON serialized
101
+ error TEXT,
102
+ created_at REAL,
103
+ storage_path TEXT,
104
+ username TEXT,
105
+ file_type TEXT,
106
+ FOREIGN KEY(username) REFERENCES users(username)
107
+ )
108
+ ''')
109
+
110
+ conn.commit()
111
+ conn.close()
112
+ logging.info(f"Database {DB_NAME} initialized successfully.")
113
+
114
+ # --- User Operations ---
115
+
116
+ def create_user(user: Dict[str, Any]) -> bool:
117
+ try:
118
+ conn = get_db_connection()
119
+ c = conn.cursor()
120
+ c.execute('''
121
+ INSERT INTO users (username, hashed_password, email, security_question, security_answer)
122
+ VALUES (?, ?, ?, ?, ?)
123
+ ''', (
124
+ user['username'],
125
+ user['hashed_password'],
126
+ user.get('email', ''),
127
+ user['security_question'],
128
+ user['security_answer']
129
+ ))
130
+ conn.commit()
131
+ return True
132
+ except sqlite3.IntegrityError:
133
+ return False
134
+ except Exception as e:
135
+ logging.error(f"Error creating user: {e}")
136
+ return False
137
+ finally:
138
+ conn.close()
139
+
140
+ def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
141
+ conn = get_db_connection()
142
+ c = conn.cursor()
143
+ c.execute('SELECT * FROM users WHERE username = ?', (username,))
144
+ row = c.fetchone()
145
+ conn.close()
146
+ if row:
147
+ return dict(row)
148
+ return None
149
+
150
+ def update_password(username: str, new_hashed_password: str) -> bool:
151
+ try:
152
+ conn = get_db_connection()
153
+ c = conn.cursor()
154
+ c.execute('UPDATE users SET hashed_password = ? WHERE username = ?', (new_hashed_password, username))
155
+ conn.commit()
156
+ conn.close()
157
+ return True
158
+ except Exception as e:
159
+ logging.error(f"Error updating password: {e}")
160
+ return False
161
+
162
+ # --- Feedback Operations ---
163
+
164
+ def add_feedback(username: str, rating: int, comment: str):
165
+ conn = get_db_connection()
166
+ c = conn.cursor()
167
+ c.execute('INSERT INTO feedback (username, rating, comment) VALUES (?, ?, ?)', (username, rating, comment))
168
+ conn.commit()
169
+ conn.close()
170
+
171
+ # --- Audit Log Operations (RGPD Compliance) ---
172
+
173
+ def log_audit(username: str, action: str, resource: str = None, ip_address: str = None):
174
+ """Log user actions for RGPD compliance and security auditing."""
175
+ try:
176
+ conn = get_db_connection()
177
+ c = conn.cursor()
178
+ c.execute(
179
+ 'INSERT INTO audit_log (username, action, resource, ip_address) VALUES (?, ?, ?, ?)',
180
+ (username, action, resource, ip_address)
181
+ )
182
+ conn.commit()
183
+ conn.close()
184
+ except Exception as e:
185
+ logging.error(f"Error logging audit: {e}")
186
+
187
+ def get_user_audit_log(username: str, limit: int = 100) -> List[Dict[str, Any]]:
188
+ """Get audit log for a specific user."""
189
+ conn = get_db_connection()
190
+ c = conn.cursor()
191
+ c.execute(
192
+ 'SELECT * FROM audit_log WHERE username = ? ORDER BY created_at DESC LIMIT ?',
193
+ (username, limit)
194
+ )
195
+ rows = c.fetchall()
196
+ conn.close()
197
+ return [dict(row) for row in rows]
198
+
199
+ # --- Analysis Registry (REAL DATA ONLY) ---
200
+
201
+ def init_analysis_registry():
202
+ """Create the analysis_registry table if it doesn't exist."""
203
+ conn = get_db_connection()
204
+ c = conn.cursor()
205
+ c.execute('''
206
+ CREATE TABLE IF NOT EXISTS analysis_registry (
207
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
208
+ username TEXT NOT NULL,
209
+ domain TEXT NOT NULL,
210
+ top_diagnosis TEXT,
211
+ confidence REAL,
212
+ priority TEXT,
213
+ computation_time_ms INTEGER,
214
+ file_type TEXT,
215
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
216
+ )
217
+ ''')
218
+ conn.commit()
219
+ conn.close()
220
+
221
+ def log_analysis(
222
+ username: str,
223
+ domain: str,
224
+ top_diagnosis: str,
225
+ confidence: float,
226
+ priority: str,
227
+ computation_time_ms: int,
228
+ file_type: str
229
+ ) -> bool:
230
+ """Log a real analysis to the registry. NO FAKE DATA."""
231
+ try:
232
+ conn = get_db_connection()
233
+ c = conn.cursor()
234
+ c.execute('''
235
+ INSERT INTO analysis_registry
236
+ (username, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type)
237
+ VALUES (?, ?, ?, ?, ?, ?, ?)
238
+ ''', (username, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type))
239
+ conn.commit()
240
+ conn.close()
241
+ return True
242
+ except Exception as e:
243
+ logging.error(f"Error logging analysis: {e}")
244
+ return False
245
+
246
+ def get_dashboard_stats(username: str) -> Dict[str, Any]:
247
+ """Get real dashboard statistics for a user. Returns zeros if no data."""
248
+ conn = get_db_connection()
249
+ c = conn.cursor()
250
+
251
+ # Total count
252
+ c.execute('SELECT COUNT(*) FROM analysis_registry WHERE username = ?', (username,))
253
+ total = c.fetchone()[0]
254
+
255
+ # By domain
256
+ c.execute('''
257
+ SELECT domain, COUNT(*) as count
258
+ FROM analysis_registry
259
+ WHERE username = ?
260
+ GROUP BY domain
261
+ ''', (username,))
262
+ by_domain = {row['domain']: row['count'] for row in c.fetchall()}
263
+
264
+ # By priority
265
+ c.execute('''
266
+ SELECT priority, COUNT(*) as count
267
+ FROM analysis_registry
268
+ WHERE username = ?
269
+ GROUP BY priority
270
+ ''', (username,))
271
+ by_priority = {row['priority']: row['count'] for row in c.fetchall()}
272
+
273
+ # Average computation time
274
+ c.execute('''
275
+ SELECT AVG(computation_time_ms)
276
+ FROM analysis_registry
277
+ WHERE username = ?
278
+ ''', (username,))
279
+ avg_time = c.fetchone()[0] or 0
280
+
281
+ conn.close()
282
+
283
+ return {
284
+ "total_analyses": total,
285
+ "by_domain": by_domain,
286
+ "by_priority": by_priority,
287
+ "avg_computation_time_ms": round(avg_time, 0)
288
+ }
289
+
290
+ def get_recent_analyses(username: str, limit: int = 10) -> List[Dict[str, Any]]:
291
+ """Get recent real analyses for a user. Returns empty list if none."""
292
+ conn = get_db_connection()
293
+ c = conn.cursor()
294
+ c.execute('''
295
+ SELECT id, domain, top_diagnosis, confidence, priority, computation_time_ms, file_type, created_at
296
+ FROM analysis_registry
297
+ WHERE username = ?
298
+ ORDER BY created_at DESC
299
+ LIMIT ?
300
+ ''', (username, limit))
301
+ rows = c.fetchall()
302
+ conn.close()
303
+ return [dict(row) for row in rows]
304
+
305
+ # --- Patient Operations (New for Migration) ---
306
+
307
+ def create_patient(
308
+ owner_username: str,
309
+ patient_id: str,
310
+ first_name: str,
311
+ last_name: str,
312
+ birth_date: str,
313
+ photo: str
314
+ ) -> Optional[int]:
315
+ """Create a new patient record."""
316
+ try:
317
+ conn = get_db_connection()
318
+ c = conn.cursor()
319
+ c.execute('''
320
+ INSERT INTO patients (owner_username, patient_id, first_name, last_name, birth_date, photo)
321
+ VALUES (?, ?, ?, ?, ?, ?)
322
+ ''', (owner_username, patient_id, first_name, last_name, birth_date, photo))
323
+ patient_id_db = c.lastrowid
324
+ conn.commit()
325
+ conn.close()
326
+ return patient_id_db
327
+ except Exception as e:
328
+ logging.error(f"Error creating patient: {e}")
329
+ return None
330
+
331
+ def get_patients_by_user(username: str) -> List[Dict[str, Any]]:
332
+ """Get all patients belonging to a user."""
333
+ conn = get_db_connection()
334
+ c = conn.cursor()
335
+ c.execute('SELECT * FROM patients WHERE owner_username = ? ORDER BY created_at DESC', (username,))
336
+ rows = c.fetchall()
337
+ conn.close()
338
+ return [dict(row) for row in rows]
339
+
340
+ def delete_patient(username: str, patient_db_id: int) -> bool:
341
+ """Delete a patient record if owned by user."""
342
+ try:
343
+ conn = get_db_connection()
344
+ c = conn.cursor()
345
+ c.execute('DELETE FROM patients WHERE id = ? AND owner_username = ?', (patient_db_id, username))
346
+ count = c.rowcount
347
+ conn.commit()
348
+ conn.close()
349
+ return count > 0
350
+ except Exception as e:
351
+ logging.error(f"Error deleting patient: {e}")
352
+ return False
353
+
354
+ def update_patient(username: str, patient_db_id: int, updates: Dict[str, Any]) -> bool:
355
+ """Update patient fields."""
356
+ try:
357
+ conn = get_db_connection()
358
+ c = conn.cursor()
359
+
360
+ # Build query dynamically
361
+ fields = []
362
+ values = []
363
+ for k, v in updates.items():
364
+ if k in ['first_name', 'last_name', 'birth_date', 'photo']:
365
+ fields.append(f"{k} = ?")
366
+ values.append(v)
367
+
368
+ if not fields:
369
+ return False
370
+
371
+ values.extend([patient_db_id, username])
372
+ query = f"UPDATE patients SET {', '.join(fields)} WHERE id = ? AND owner_username = ?"
373
+
374
+
375
+ c.execute(query, values)
376
+ count = c.rowcount
377
+ conn.commit()
378
+ conn.close()
379
+ return count > 0
380
+ except Exception as e:
381
+ logging.error(f"Error updating patient: {e}")
382
+ return False
383
+
384
+ # --- Job Operations (Persistence) ---
385
+
386
+ import json
387
+
388
+ def create_job(job_data: Dict[str, Any]):
389
+ """Create a new job record."""
390
+ try:
391
+ conn = get_db_connection()
392
+ c = conn.cursor()
393
+ c.execute('''
394
+ INSERT INTO jobs (id, status, result, error, created_at, storage_path, username, file_type)
395
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
396
+ ''', (
397
+ job_data['id'],
398
+ job_data.get('status', 'pending'),
399
+ json.dumps(job_data.get('result')) if job_data.get('result') else None,
400
+ job_data.get('error'),
401
+ job_data['created_at'],
402
+ job_data.get('storage_path'),
403
+ job_data.get('username'),
404
+ job_data.get('file_type')
405
+ ))
406
+ conn.commit()
407
+ conn.close()
408
+ return True
409
+ except Exception as e:
410
+ logging.error(f"Error creating job: {e}")
411
+ return False
412
+
413
+ def get_job(job_id: str, username: Optional[str] = None) -> Optional[Dict[str, Any]]:
414
+ """Retrieve job by ID, optionally enforcing ownership via SQL."""
415
+ conn = get_db_connection()
416
+ c = conn.cursor()
417
+
418
+ if username:
419
+ c.execute('SELECT * FROM jobs WHERE id = ? AND username = ?', (job_id, username))
420
+ else:
421
+ c.execute('SELECT * FROM jobs WHERE id = ?', (job_id,))
422
+
423
+ row = c.fetchone()
424
+ conn.close()
425
+
426
+ if row:
427
+ job = dict(row)
428
+ if job['result']:
429
+ try:
430
+ job['result'] = json.loads(job['result'])
431
+ except:
432
+ job['result'] = None
433
+ return job
434
+ return None
435
+
436
+ def update_job_status(job_id: str, status: str, result: Optional[Dict] = None, error: Optional[str] = None):
437
+ """Update job status and result."""
438
+ try:
439
+ conn = get_db_connection()
440
+ c = conn.cursor()
441
+
442
+ updates = ["status = ?"]
443
+ params = [status]
444
+
445
+ if result is not None:
446
+ updates.append("result = ?")
447
+ params.append(json.dumps(result))
448
+
449
+ if error is not None:
450
+ updates.append("error = ?")
451
+ params.append(error)
452
+
453
+ params.append(job_id)
454
+
455
+ query = f"UPDATE jobs SET {', '.join(updates)} WHERE id = ?"
456
+ c.execute(query, params)
457
+ conn.commit()
458
+ conn.close()
459
+ return True
460
+ except Exception as e:
461
+ logging.error(f"Error updating job: {e}")
462
+ return False
463
+
464
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dicom_processor.py CHANGED
@@ -1,167 +1,167 @@
1
- import pydicom
2
- import logging
3
- import hashlib
4
- from typing import Tuple, Dict, Any, Optional
5
- from pathlib import Path
6
- import os
7
- import io
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
- # Mandatory DICOM Tags for Medical Validity
12
- REQUIRED_TAGS = [
13
- 'PatientID',
14
- 'StudyInstanceUID',
15
- 'SeriesInstanceUID',
16
- 'Modality',
17
- 'PixelSpacing', # Crucial for measurements
18
- ]
19
-
20
- # Tags to Anonymize (PHI)
21
- PHI_TAGS = [
22
- 'PatientName',
23
- 'PatientBirthDate',
24
- 'PatientAddress',
25
- 'InstitutionName',
26
- 'ReferringPhysicianName'
27
- ]
28
-
29
- def validate_dicom(file_bytes: bytes) -> pydicom.dataset.FileDataset:
30
- """
31
- Strict validation of DICOM file.
32
- Raises ValueError if invalid.
33
- """
34
- try:
35
- # 1. Parse without loading pixel data first (speed)
36
- ds = pydicom.dcmread(io.BytesIO(file_bytes), stop_before_pixels=False)
37
- except Exception as e:
38
- raise ValueError(f"Invalid DICOM format: {str(e)}")
39
-
40
- # 2. Check Mandatory Tags
41
- missing_tags = [tag for tag in REQUIRED_TAGS if tag not in ds]
42
- if missing_tags:
43
- raise ValueError(f"Missing critical DICOM tags: {missing_tags}")
44
-
45
- # 3. Check Pixel Data presence
46
- if 'PixelData' not in ds:
47
- raise ValueError("DICOM file has no image data (PixelData missing).")
48
-
49
- return ds
50
-
51
- def anonymize_dicom(ds: pydicom.dataset.FileDataset) -> pydicom.dataset.FileDataset:
52
- """
53
- Remove PHI from dataset.
54
- Returns modified dataset.
55
- """
56
- # Hash PatientID to keep linkable anonymous ID
57
- original_id = str(ds.get('PatientID', 'Unknown'))
58
- hashed_id = hashlib.sha256(original_id.encode()).hexdigest()[:16].upper()
59
-
60
- ds.PatientID = f"ANON-{hashed_id}"
61
-
62
- # Wipe other fields
63
- for tag in PHI_TAGS:
64
- if tag in ds:
65
- if 'Date' in tag: # VR DA requires YYYYMMDD
66
- ds.data_element(tag).value = "19010101"
67
- else:
68
- ds.data_element(tag).value = "ANONYMIZED"
69
-
70
- return ds
71
-
72
- def process_dicom_upload(file_bytes: bytes, username: str) -> Tuple[bytes, Dict[str, Any]]:
73
- """
74
- Main Gateway Function: Validate -> Anonymize -> Return Bytes & Metadata
75
- """
76
- # 1. Validate
77
- try:
78
- ds = validate_dicom(file_bytes)
79
- except Exception as e:
80
- logger.error(f"DICOM Validation Failed: {e}")
81
- raise ValueError(f"DICOM Rejected: {e}")
82
-
83
- # 2. Anonymize
84
- ds = anonymize_dicom(ds)
85
-
86
- # 3. Extract safe metadata
87
- metadata = {
88
- "modality": ds.get("Modality", "Unknown"),
89
- "body_part": ds.get("BodyPartExamined", "Unknown"),
90
- "study_uid": str(ds.get("StudyInstanceUID", "")),
91
- "pixel_spacing": ds.get("PixelSpacing", [1.0, 1.0]),
92
- "original_filename_hint": "dicom_file.dcm"
93
- }
94
-
95
- # 4. Convert back to bytes for storage
96
- with io.BytesIO() as buffer:
97
- ds.save_as(buffer)
98
- safe_bytes = buffer.getvalue()
99
-
100
- return safe_bytes, metadata
101
-
102
- def convert_dicom_to_image(ds: pydicom.dataset.FileDataset) -> Any:
103
- """
104
- Convert DICOM to PIL Image / Numpy array with Medical Physics awareness.
105
- 1. Check RAS Orientation (Basic Validation).
106
- 2. Apply Hounsfield Units (CT) or Intensity Normalization (MRI/XRay).
107
- 3. Windowing (Lung/Bone/Soft Tissue).
108
- """
109
- import numpy as np
110
- from PIL import Image
111
-
112
- try:
113
- # 1. Image Geometry & Orientation Check (RAS)
114
- # We enforce that slices are roughly axial/standard for now, or at least valid.
115
- orientation = ds.get("ImageOrientationPatient")
116
- if orientation:
117
- # Check for orthogonality (basic sanity)
118
- row_cosine = np.array(orientation[:3])
119
- col_cosine = np.array(orientation[3:])
120
- if np.abs(np.dot(row_cosine, col_cosine)) > 1e-3:
121
- logger.warning("DICOM Orientation vectors are not orthogonal. Image might be skewed.")
122
-
123
- # 2. Extract Raw Pixels
124
- pixel_array = ds.pixel_array.astype(float)
125
-
126
- # 3. Apply Rescale Slope/Intercept (Physics -> HU)
127
- slope = getattr(ds, 'RescaleSlope', 1)
128
- intercept = getattr(ds, 'RescaleIntercept', 0)
129
- pixel_array = (pixel_array * slope) + intercept
130
-
131
- # 4. Modality-Specific Normalization
132
- modality = ds.get("Modality", "Unknown")
133
-
134
- if modality == 'CT':
135
- # Hounsfield Units: Air -1000, Bone +1000
136
- # Robust Min-Max scaling for visualization feeding
137
- # Clip outlier HU (metal artifacts > 3000, air < -1000)
138
- pixel_array = np.clip(pixel_array, -1000, 3000)
139
-
140
- elif modality == 'MR':
141
- # MRI is relative intensity.
142
- # Simple 1-99 percentile clipping removes spikes.
143
- p1, p99 = np.percentile(pixel_array, [1, 99])
144
- pixel_array = np.clip(pixel_array, p1, p99)
145
-
146
- # 5. Normalization to 0-255 (Display Space)
147
- pixel_min = np.min(pixel_array)
148
- pixel_max = np.max(pixel_array)
149
-
150
- if pixel_max - pixel_min != 0:
151
- pixel_array = ((pixel_array - pixel_min) / (pixel_max - pixel_min)) * 255.0
152
- else:
153
- pixel_array = np.zeros_like(pixel_array)
154
-
155
- pixel_array = pixel_array.astype(np.uint8)
156
-
157
- # 6. Color Space
158
- if len(pixel_array.shape) == 2:
159
- image = Image.fromarray(pixel_array).convert("RGB")
160
- else:
161
- image = Image.fromarray(pixel_array)
162
-
163
- return image
164
-
165
- except Exception as e:
166
- logger.error(f"DICOM Conversion Error: {e}")
167
- raise ValueError(f"Could not convert DICOM to image: {e}")
 
1
+ import pydicom
2
+ import logging
3
+ import hashlib
4
+ from typing import Tuple, Dict, Any, Optional
5
+ from pathlib import Path
6
+ import os
7
+ import io
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
# Mandatory DICOM Tags for Medical Validity
# Keywords that must be present in an uploaded dataset; validate_dicom()
# raises ValueError listing any that are absent.
REQUIRED_TAGS = [
    'PatientID',
    'StudyInstanceUID',
    'SeriesInstanceUID',
    'Modality',
    'PixelSpacing',  # Crucial for measurements
]

# Tags to Anonymize (PHI)
# Keywords that anonymize_dicom() overwrites with a placeholder when present.
# PatientID is not listed here because it is replaced by a hash instead.
# NOTE(review): this list is presumably not an exhaustive PHI inventory —
# confirm against the project's de-identification requirements.
PHI_TAGS = [
    'PatientName',
    'PatientBirthDate',
    'PatientAddress',
    'InstitutionName',
    'ReferringPhysicianName'
]
28
+
29
def validate_dicom(file_bytes: bytes) -> pydicom.dataset.FileDataset:
    """
    Strict validation of DICOM file.

    Parses the raw bytes, then checks that every keyword in REQUIRED_TAGS
    and the PixelData element are present.

    Args:
        file_bytes: Raw content of the uploaded file.

    Returns:
        The parsed dataset.

    Raises:
        ValueError: If the bytes cannot be parsed, a required tag is
            missing, or PixelData is absent. Parse failures keep the
            original exception as ``__cause__``.
    """
    try:
        # 1. Parse the whole file, pixel data included. stop_before_pixels
        #    must stay False: the PixelData check in step 3 relies on the
        #    element having been read.
        ds = pydicom.dcmread(io.BytesIO(file_bytes), stop_before_pixels=False)
    except Exception as e:
        # Chain the cause so the underlying parser error stays debuggable.
        raise ValueError(f"Invalid DICOM format: {str(e)}") from e

    # 2. Check Mandatory Tags
    missing_tags = [tag for tag in REQUIRED_TAGS if tag not in ds]
    if missing_tags:
        raise ValueError(f"Missing critical DICOM tags: {missing_tags}")

    # 3. Check Pixel Data presence
    if 'PixelData' not in ds:
        raise ValueError("DICOM file has no image data (PixelData missing).")

    return ds
50
+
51
def anonymize_dicom(ds: pydicom.dataset.FileDataset) -> pydicom.dataset.FileDataset:
    """
    Strip PHI from the dataset in place and hand the same object back.

    PatientID becomes "ANON-" plus the first 16 upper-cased hex digits of
    its SHA-256, so the same source ID always maps to the same anonymous ID.
    Every keyword in PHI_TAGS that is present is overwritten with a
    placeholder.
    """
    source_id = str(ds.get('PatientID', 'Unknown'))
    digest = hashlib.sha256(source_id.encode()).hexdigest()
    ds.PatientID = f"ANON-{digest[:16].upper()}"

    for keyword in PHI_TAGS:
        if keyword not in ds:
            continue
        # Date elements (VR DA) need a YYYYMMDD value; everything else gets
        # a fixed marker string.
        placeholder = "19010101" if 'Date' in keyword else "ANONYMIZED"
        ds.data_element(keyword).value = placeholder

    return ds
71
+
72
def process_dicom_upload(file_bytes: bytes, username: str) -> Tuple[bytes, Dict[str, Any]]:
    """
    Main Gateway Function: Validate -> Anonymize -> Return Bytes & Metadata

    Args:
        file_bytes: Raw content of the uploaded DICOM file.
        username: Uploader's name. Accepted for interface compatibility;
            this function does not currently read it.

    Returns:
        A tuple ``(safe_bytes, metadata)``: the anonymized dataset
        re-serialized to bytes, and a dict of non-identifying fields.

    Raises:
        ValueError: If validation rejects the file; the underlying error is
            kept as ``__cause__``.
    """
    # 1. Validate
    try:
        ds = validate_dicom(file_bytes)
    except Exception as e:
        logger.error(f"DICOM Validation Failed: {e}")
        raise ValueError(f"DICOM Rejected: {e}") from e

    # 2. Anonymize
    ds = anonymize_dicom(ds)

    # 3. Extract safe metadata
    # PixelSpacing comes back from pydicom as its own multi-value container;
    # flatten it to a plain list of floats so the metadata dict holds only
    # built-in types (e.g. for JSON serialization).
    raw_spacing = ds.get("PixelSpacing", [1.0, 1.0])
    try:
        pixel_spacing = [float(v) for v in raw_spacing]
    except (TypeError, ValueError):
        # Empty or malformed value: pass it through untouched rather than
        # inventing a spacing.
        pixel_spacing = raw_spacing

    metadata = {
        "modality": ds.get("Modality", "Unknown"),
        "body_part": ds.get("BodyPartExamined", "Unknown"),
        "study_uid": str(ds.get("StudyInstanceUID", "")),
        "pixel_spacing": pixel_spacing,
        "original_filename_hint": "dicom_file.dcm"
    }

    # 4. Convert back to bytes for storage
    with io.BytesIO() as buffer:
        ds.save_as(buffer)
        safe_bytes = buffer.getvalue()

    return safe_bytes, metadata
101
+
102
def convert_dicom_to_image(ds: pydicom.dataset.FileDataset) -> Any:
    """
    Render a parsed DICOM dataset as an 8-bit PIL image.

    Steps:
    1. Warn when the ImageOrientationPatient row/column vectors are not
       orthogonal (no correction is applied).
    2. Apply RescaleSlope / RescaleIntercept to the raw pixels.
    3. Clip by modality: CT to [-1000, 3000], MR to its 1st-99th percentile.
    4. Min-max scale to 0-255; 2-D results are converted to RGB.

    Raises ValueError if any step fails.
    """
    import numpy as np
    from PIL import Image

    try:
        # Geometry sanity check only: a skewed orientation is logged.
        iop = ds.get("ImageOrientationPatient")
        if iop:
            row_dir, col_dir = np.array(iop[:3]), np.array(iop[3:])
            skew = np.abs(np.dot(row_dir, col_dir))
            if skew > 1e-3:
                logger.warning("DICOM Orientation vectors are not orthogonal. Image might be skewed.")

        # Raw stored values -> rescaled values (slope 1 / intercept 0 when
        # the dataset does not carry the attributes).
        pixels = ds.pixel_array.astype(float)
        pixels = (pixels * getattr(ds, 'RescaleSlope', 1)) + getattr(ds, 'RescaleIntercept', 0)

        # Modality-specific outlier suppression.
        kind = ds.get("Modality", "Unknown")
        if kind == 'CT':
            pixels = np.clip(pixels, -1000, 3000)
        elif kind == 'MR':
            low, high = np.percentile(pixels, [1, 99])
            pixels = np.clip(pixels, low, high)

        # Stretch to the 0-255 display range; a constant image maps to black.
        floor = np.min(pixels)
        ceiling = np.max(pixels)
        span = ceiling - floor
        if span == 0:
            pixels = np.zeros_like(pixels)
        else:
            pixels = ((pixels - floor) / span) * 255.0

        frame = pixels.astype(np.uint8)

        # Single-channel arrays are expanded to RGB; anything else is passed
        # to PIL as-is.
        rendered = Image.fromarray(frame)
        if frame.ndim == 2:
            rendered = rendered.convert("RGB")

        return rendered

    except Exception as e:
        logger.error(f"DICOM Conversion Error: {e}")
        raise ValueError(f"Could not convert DICOM to image: {e}")
encryption.py CHANGED
@@ -1,71 +1,71 @@
1
- from cryptography.fernet import Fernet
2
- import os
3
- import sys
4
- import logging
5
- from typing import Optional
6
-
7
- # -------------------------------------------------------------------------
8
- # ENCRYPTION CONFIGURATION - PRODUCTION READY
9
- # -------------------------------------------------------------------------
10
-
11
- # Environment detection
12
- ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
13
- IS_PRODUCTION = ENVIRONMENT == "production"
14
-
15
- # Encryption Key - Load from environment variable
16
- ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY")
17
-
18
- if not ENCRYPTION_KEY:
19
- if IS_PRODUCTION:
20
- logging.critical("🔴 FATAL ERROR: ENCRYPTION_KEY must be set in production environment")
21
- logging.critical("Generate one with: python -c 'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'")
22
- sys.exit(1) # Fail-fast in production
23
- else:
24
- # Development fallback with ephemeral key
25
- ENCRYPTION_KEY = Fernet.generate_key().decode()
26
- logging.warning("⚠️ WARNING: Using ephemeral encryption key (development only)")
27
-
28
- # Initialize cipher
29
- cipher_suite = Fernet(ENCRYPTION_KEY.encode() if isinstance(ENCRYPTION_KEY, str) else ENCRYPTION_KEY)
30
-
31
- def encrypt_data(data: str) -> str:
32
- """
33
- Encrypts a string and returns the encrypted token as a string.
34
- """
35
- if not data: return ""
36
- encrypted_bytes = cipher_suite.encrypt(data.encode('utf-8'))
37
- return encrypted_bytes.decode('utf-8')
38
-
39
- def decrypt_data(token: str) -> Optional[str]:
40
- """
41
- Decrypts a token and returns the original string.
42
- """
43
- if not token: return None
44
- try:
45
- decrypted_bytes = cipher_suite.decrypt(token.encode('utf-8'))
46
- return decrypted_bytes.decode('utf-8')
47
- except Exception as e:
48
- print(f"Decryption failed: {e}")
49
- return None
50
-
51
- def rotate_key():
52
- """
53
- Example function to rotate keys (advanced).
54
- """
55
- global key, cipher_suite
56
- key = Fernet.generate_key()
57
- cipher_suite = Fernet(key)
58
- with open(ENCRYPTION_KEY_PATH, "wb") as key_file:
59
- key_file.write(key)
60
- print(f"New key generated and saved to {ENCRYPTION_KEY_PATH}")
61
-
62
- if __name__ == "__main__":
63
- # Test
64
- original = "Jean Dupont - Patient Zero"
65
- encrypted = encrypt_data(original)
66
- decrypted = decrypt_data(encrypted)
67
-
68
- print(f"Original: {original}")
69
- print(f"Encrypted: {encrypted}")
70
- print(f"Decrypted: {decrypted}")
71
- assert original == decrypted
 
1
+ from cryptography.fernet import Fernet
2
+ import os
3
+ import sys
4
+ import logging
5
+ from typing import Optional
6
+
7
# -------------------------------------------------------------------------
# ENCRYPTION CONFIGURATION - PRODUCTION READY
# -------------------------------------------------------------------------
# Everything below runs at import time: importing this module reads the
# environment, may terminate the process, and builds the shared cipher.

# Environment detection
# ENVIRONMENT defaults to "development"; only the exact value "production"
# enables the fail-fast branch below.
ENVIRONMENT = os.getenv("ENVIRONMENT", "development")
IS_PRODUCTION = ENVIRONMENT == "production"

# Encryption Key - Load from environment variable
ENCRYPTION_KEY = os.getenv("ENCRYPTION_KEY")

# An unset or empty ENCRYPTION_KEY is treated as missing.
if not ENCRYPTION_KEY:
    if IS_PRODUCTION:
        logging.critical("🔴 FATAL ERROR: ENCRYPTION_KEY must be set in production environment")
        logging.critical("Generate one with: python -c 'from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())'")
        sys.exit(1)  # Fail-fast in production
    else:
        # Development fallback with ephemeral key: generated fresh on every
        # import, so tokens produced in one process cannot be decrypted
        # after a restart.
        ENCRYPTION_KEY = Fernet.generate_key().decode()
        logging.warning("⚠️ WARNING: Using ephemeral encryption key (development only)")

# Initialize cipher
# The key is encoded to bytes when it is held as a str.
cipher_suite = Fernet(ENCRYPTION_KEY.encode() if isinstance(ENCRYPTION_KEY, str) else ENCRYPTION_KEY)
30
+
31
def encrypt_data(data: str) -> str:
    """
    Encrypt *data* with the module cipher and return the token as text.

    Falsy input (empty string or None) short-circuits to an empty string
    without touching the cipher.
    """
    if data:
        token = cipher_suite.encrypt(data.encode('utf-8'))
        return token.decode('utf-8')
    return ""
38
+
39
def decrypt_data(token: str) -> Optional[str]:
    """
    Decrypts a token and returns the original string.

    Args:
        token: Text token as produced by ``encrypt_data``.

    Returns:
        The decrypted string, or ``None`` when *token* is falsy or
        decryption fails for any reason (the failure is logged).
    """
    if not token:
        return None
    try:
        decrypted_bytes = cipher_suite.decrypt(token.encode('utf-8'))
        return decrypted_bytes.decode('utf-8')
    except Exception as e:
        # Deliberately best-effort: callers receive None instead of an
        # exception. Report through logging (as the rest of this module
        # does) rather than print, so failures reach the configured handlers.
        logging.error(f"Decryption failed: {e}")
        return None
50
+
51
def rotate_key(key_path: Optional[str] = None):
    """
    Generate a new Fernet key, save it to disk, and switch the module to it.

    The previous implementation referenced an undefined
    ``ENCRYPTION_KEY_PATH`` name and therefore raised NameError after it had
    already replaced the in-memory cipher. The destination is now resolved
    up front and the key is written before any module state changes.

    Args:
        key_path: File to write the new key to. Defaults to the
            ``ENCRYPTION_KEY_PATH`` environment variable.

    Raises:
        RuntimeError: If no destination path is available.
        OSError: If the key file cannot be written; module state is left
            untouched in that case.

    Note:
        After rotation this process can no longer decrypt tokens created
        with the previous key; re-encrypt stored data before rotating.
    """
    global ENCRYPTION_KEY, cipher_suite

    target_path = key_path or os.getenv("ENCRYPTION_KEY_PATH")
    if not target_path:
        raise RuntimeError(
            "Cannot rotate key: pass key_path or set ENCRYPTION_KEY_PATH so the new key can be saved"
        )

    new_key = Fernet.generate_key()

    # Persist first, swap second: a failed write must not leave the process
    # encrypting with a key that exists nowhere else. The 0o600 mode only
    # applies when the file is newly created.
    fd = os.open(target_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "wb") as key_file:
        key_file.write(new_key)

    ENCRYPTION_KEY = new_key.decode()
    cipher_suite = Fernet(new_key)
    logging.info(f"New key generated and saved to {target_path}")
61
+
62
if __name__ == "__main__":
    # Manual smoke test, run only when the module is executed directly:
    # encrypt a sample string, decrypt the token, and check the round trip.
    original = "Jean Dupont - Patient Zero"
    encrypted = encrypt_data(original)
    decrypted = decrypt_data(encrypted)

    print(f"Original: {original}")
    print(f"Encrypted: {encrypted}")
    print(f"Decrypted: {decrypted}")
    # Fails loudly if decryption did not reproduce the input.
    assert original == decrypted
explainability.py CHANGED
@@ -1,462 +1,378 @@
1
-
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- import cv2
6
- from PIL import Image
7
- import logging
8
- from typing import List, Dict, Any, Optional, Tuple, Union
9
- from pytorch_grad_cam import GradCAMPlusPlus
10
- from pytorch_grad_cam.utils.image import show_cam_on_image
11
- from dataclasses import dataclass
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
- # =========================================================================
16
- # CONFIGURATION & EXPERT KNOWLEDGE
17
- # =========================================================================
18
-
19
- @dataclass
20
- class ExpertSegConfig:
21
- modality: str
22
- target_organ: str
23
- anatomical_prompts: List[str] # For Segmentation Mask
24
- threshold_percentile: int # Top X% activation
25
- min_area_ratio: float
26
- max_area_ratio: float
27
- morphology_kernel: int
28
-
29
- # Expert Knowledge Base
30
- EXPERT_KNOWLEDGE = {
31
- "Thoracic": ExpertSegConfig(
32
- modality="CXR/CT",
33
- target_organ="Lung Parenchyma",
34
- anatomical_prompts=[
35
- "lung parenchyma",
36
- "bilateral lungs",
37
- "pulmonary fields",
38
- "chest x-ray lungs excluding heart"
39
- ],
40
- threshold_percentile=75, # Top 25%
41
- min_area_ratio=0.15,
42
- max_area_ratio=0.60,
43
- morphology_kernel=7
44
- ),
45
- "Orthopedics": ExpertSegConfig(
46
- modality="X-Ray",
47
- target_organ="Bone Structure",
48
- anatomical_prompts=[
49
- "bone structure",
50
- "knee joint",
51
- "cortical bone",
52
- "skeletal anatomy"
53
- ],
54
- threshold_percentile=85, # Top 15%
55
- min_area_ratio=0.05,
56
- max_area_ratio=0.50,
57
- morphology_kernel=5
58
- ),
59
- "Default": ExpertSegConfig(
60
- modality="General",
61
- target_organ="Body Part",
62
- anatomical_prompts=["medical image body part"],
63
- threshold_percentile=80,
64
- min_area_ratio=0.05,
65
- max_area_ratio=0.90,
66
- morphology_kernel=5
67
- )
68
- }
69
-
70
- # =========================================================================
71
- # WRAPPERS AND UTILS
72
- # =========================================================================
73
-
74
- class HuggingFaceWeirdCLIPWrapper(nn.Module):
75
- """
76
- Wraps SigLIP to act like a standard classifier for Grad-CAM.
77
- Target: Cosine Similarity Score.
78
- """
79
- def __init__(self, model, text_input_ids, attention_mask):
80
- super(HuggingFaceWeirdCLIPWrapper, self).__init__()
81
- self.model = model
82
- self.text_input_ids = text_input_ids
83
- self.attention_mask = attention_mask
84
-
85
- def forward(self, pixel_values):
86
- outputs = self.model(
87
- pixel_values=pixel_values,
88
- input_ids=self.text_input_ids,
89
- attention_mask=self.attention_mask
90
- )
91
- # outputs.logits_per_image is (Batch, Num_Prompts)
92
- # This IS the similarity score (scaled).
93
- # Grad-CAM++ will derive gradients relative to this score.
94
- return outputs.logits_per_image
95
-
96
- def reshape_transform(tensor, width=32, height=32):
97
- """Reshape Transformer attention/embeddings for Grad-CAM."""
98
- # Squeeze CLS if present logic (usually SigLIP doesn't have it in last layers same way)
99
- # Tensor: (Batch, Num_Tokens, Dim)
100
- num_tokens = tensor.size(1)
101
- side = int(np.sqrt(num_tokens))
102
- result = tensor.reshape(tensor.size(0), side, side, tensor.size(2))
103
- # Bring channels first: (B, C, H, W)
104
- result = result.transpose(2, 3).transpose(1, 2)
105
- return result
106
-
107
- # =========================================================================
108
- # EXPERT+ EXPLAINABILITY ENGINE
109
- # =========================================================================
110
-
111
- class ExplainabilityEngine:
112
- def __init__(self, model_wrapper):
113
- self.wrapper = model_wrapper
114
- self.model = model_wrapper.model
115
- self.processor = model_wrapper.processor
116
- self.device = self.model.device
117
-
118
- def _get_expert_config(self, anatomical_context: str) -> ExpertSegConfig:
119
- if "lung" in anatomical_context.lower():
120
- return EXPERT_KNOWLEDGE["Thoracic"]
121
- elif "bone" in anatomical_context.lower() or "knee" in anatomical_context.lower():
122
- return EXPERT_KNOWLEDGE["Orthopedics"]
123
- else:
124
- base = EXPERT_KNOWLEDGE["Default"]
125
- base.anatomical_prompts = [anatomical_context]
126
- return base
127
-
128
- def generate_expert_mask(self, image: Image.Image, config: ExpertSegConfig) -> Dict[str, Any]:
129
- """
130
- Expert Segmentation:
131
- Multi-Prompt Ensembling -> Patch Similarity -> Adaptive Threshold -> Morphology -> Validation.
132
- """
133
- audit = {
134
- "seg_prompts": config.anatomical_prompts,
135
- "seg_status": "INIT"
136
- }
137
- try:
138
- w, h = image.size
139
- inputs = self.processor(text=config.anatomical_prompts, images=image, padding="max_length", return_tensors="pt")
140
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
141
-
142
- with torch.no_grad():
143
- # Vision Features (1, Token, Dim)
144
- vision_outputs = self.model.vision_model(
145
- pixel_values=inputs["pixel_values"],
146
- output_hidden_states=True
147
- )
148
- last_hidden_state = vision_outputs.last_hidden_state
149
-
150
- # Text Features (Prompts, Dim)
151
- # Text Features (Prompts, Dim)
152
- # FIX: Robustly handle attention_mask (some processors don't return it for text-only inputs if irrelevant)
153
- text_inputs_ids = inputs["input_ids"]
154
- text_attention_mask = inputs.get("attention_mask")
155
-
156
- if text_attention_mask is None:
157
- text_attention_mask = torch.ones_like(text_inputs_ids)
158
-
159
- text_outputs = self.model.text_model(
160
- input_ids=text_inputs_ids,
161
- attention_mask=text_attention_mask
162
- )
163
- text_embeds = text_outputs.pooler_output
164
- text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
165
-
166
- # Similarity: (1, T, D) @ (D, P) -> (1, T, P)
167
- sim_map = torch.matmul(last_hidden_state, text_embeds.t())
168
- # Mean across Prompts -> (1, T)
169
- sim_map = sim_map.mean(dim=2)
170
-
171
- # Reshape & Upscale
172
- num_tokens = sim_map.size(1)
173
- side = int(np.sqrt(num_tokens))
174
- sim_grid = sim_map.reshape(1, side, side)
175
-
176
- sim_grid = torch.nn.functional.interpolate(
177
- sim_grid.unsqueeze(0),
178
- size=(h, w),
179
- mode='bilinear',
180
- align_corners=False
181
- ).squeeze().cpu().numpy()
182
-
183
- # Adaptive Thresholding (Percentile)
184
- thresh = np.percentile(sim_grid, config.threshold_percentile)
185
- binary_mask = (sim_grid > thresh).astype(np.float32)
186
- audit["seg_threshold"] = float(thresh)
187
-
188
- # Morphological Cleaning
189
- kernel = np.ones((config.morphology_kernel, config.morphology_kernel), np.uint8)
190
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel) # Remove noise
191
- binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel) # Fill holes
192
- binary_mask = cv2.GaussianBlur(binary_mask, (15, 15), 0) # Smooth contours
193
- binary_mask = (binary_mask - binary_mask.min()) / (binary_mask.max() - binary_mask.min() + 1e-8)
194
-
195
- # Validation
196
- val = self._validate_mask(binary_mask, config)
197
- audit["seg_validation"] = val
198
-
199
- if not val["valid"]:
200
- logger.warning(f"Mask Invalid: {val['reason']}")
201
- return {"mask": None, "audit": audit}
202
-
203
- return {"mask": binary_mask, "audit": audit}
204
-
205
- except Exception as e:
206
- logger.error(f"Segmentation Failed: {e}")
207
- audit["seg_error"] = str(e)
208
- return {"mask": None, "audit": audit}
209
-
210
- def _validate_mask(self, mask: np.ndarray, config: ExpertSegConfig) -> Dict[str, Any]:
211
- area_ratio = np.sum(mask > 0.5) / mask.size
212
-
213
- if area_ratio < config.min_area_ratio:
214
- return {"valid": False, "reason": f"Small Area: {area_ratio:.2f} < {config.min_area_ratio}"}
215
- if area_ratio > config.max_area_ratio:
216
- return {"valid": False, "reason": f"Large Area: {area_ratio:.2f} > {config.max_area_ratio}"}
217
-
218
- # Connectivity Check (Constraint: "suppression du bruit bas" / continuity)
219
- # Ensure we have large connected components, not confetti
220
- # For now, strict Area check + Opening usually covers this.
221
- return {"valid": True}
222
-
223
- def generate_expert_gradcam(self, image: Image.Image, target_prompts: List[str]) -> Dict[str, Any]:
224
- """
225
- Expert Grad-CAM:
226
- 1. Multi-Prompt Ensembling (Averaging heatmaps).
227
- 2. Layer Selection: Encoder Layer -2.
228
- 3. Target: Cosine Score.
229
- """
230
- audit = {"gradcam_prompts": target_prompts, "gradcam_status": "INIT"}
231
-
232
- try:
233
- # Prepare Inputs
234
- inputs = self.processor(text=target_prompts, images=image, padding="max_length", return_tensors="pt")
235
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
236
-
237
- # Robust Mask handling
238
- input_ids = inputs.get('input_ids')
239
- attention_mask = inputs.get('attention_mask')
240
- if attention_mask is None and input_ids is not None:
241
- attention_mask = torch.ones_like(input_ids)
242
-
243
- # Wrapper
244
- model_wrapper_cam = HuggingFaceWeirdCLIPWrapper(self.model, input_ids, attention_mask)
245
-
246
- # Layer Selection: 2nd to last encoder layer (Better spatial features than last Norm)
247
- # SigLIP structure: model.vision_model.encoder.layers
248
- target_layers = [self.model.vision_model.encoder.layers[-2].layer_norm1]
249
-
250
- cam = GradCAMPlusPlus(
251
- model=model_wrapper_cam,
252
- target_layers=target_layers,
253
- reshape_transform=reshape_transform # Needs to handle (B, T, D)
254
- )
255
-
256
- pixel_values = inputs.get('pixel_values')
257
-
258
- # ENSEMBLING GRAD-CAM
259
- # We want to run Grad-CAM for EACH prompt index and average them.
260
- # Grayscale CAM output is (Batch, H, W)
261
- # We assume Batch=1 here.
262
-
263
- maps = []
264
- for i in range(len(target_prompts)):
265
- # Target Class Index = i (The index of the prompt in the logits)
266
- # GradCAMPlusPlus targets=[ClassifierOutputTarget(i)]
267
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
268
-
269
- targets = [ClassifierOutputTarget(i)]
270
- grayscale_cam = cam(input_tensor=pixel_values, targets=targets)
271
- maps.append(grayscale_cam[0, :])
272
-
273
- # Average
274
- avg_cam = np.mean(np.array(maps), axis=0)
275
-
276
- # Point 5: Smart Normalization & Thresholding
277
- # "cam = normalize(cam)"
278
- if avg_cam.max() > avg_cam.min():
279
- avg_cam = (avg_cam - avg_cam.min()) / (avg_cam.max() - avg_cam.min())
280
-
281
- # "mask = cam > percentile(cam, 85)" - Removing low confidence noise
282
- # We keep it continuous for heatmap but suppress low values
283
- # Using 80th percentile as soft threshold (User said 85, let's use 80 to be safe but clean)
284
- cam_threshold = np.percentile(avg_cam, 80)
285
- avg_cam[avg_cam < cam_threshold] = 0.0
286
-
287
- # Re-normalize the top 20% to spread 0-1 for visibility
288
- if avg_cam.max() > 0:
289
- avg_cam = avg_cam / avg_cam.max()
290
-
291
- # Smoothing after thresholding to remove jagged edges
292
- avg_cam = cv2.GaussianBlur(avg_cam, (11, 11), 0)
293
-
294
- audit["gradcam_threshold_val"] = float(cam_threshold)
295
-
296
- return {"map": avg_cam, "audit": audit}
297
-
298
- except Exception as e:
299
- logger.error(f"Grad-CAM Failed: {e}")
300
- audit["gradcam_error"] = str(e)
301
- return {"map": None, "audit": audit}
302
-
303
- def explain(self, image: Image.Image, target_text: str, anatomical_context: str) -> Dict[str, Any]:
304
- """
305
- Final Expert Fusion Pipeline.
306
- """
307
- # 0. Setup
308
- config = self._get_expert_config(anatomical_context)
309
-
310
- # 1. Anatomical Mask (Strict Constraint)
311
- seg_res = self.generate_expert_mask(image, config)
312
- mask = seg_res["mask"]
313
- audit = seg_res["audit"]
314
-
315
- if mask is None:
316
- # Strict Safety: No Explanation if Segmentation fails.
317
- return {
318
- "heatmap_array": None,
319
- "heatmap_raw": None,
320
- "reliability_score": 0.0,
321
- "confidence_label": "UNSAFE", # Point 8
322
- "audit": audit,
323
- "display_text": "Validation Anatomique Échouée"
324
- }
325
-
326
- # 2. Attention Map (Multi-Prompt)
327
- # Using list of prompts implies Multi-Prompt Grad-CAM (Point 4)
328
- # We can auto-augment target_text if needed, but for now we trust the input.
329
- gradcam_res = self.generate_expert_gradcam(image, [target_text])
330
- heatmap = gradcam_res["map"]
331
- audit.update(gradcam_res["audit"])
332
-
333
- if heatmap is None:
334
- return {
335
- "heatmap_array": None,
336
- "heatmap_raw": None,
337
- "reliability_score": 0.0,
338
- "confidence_label": "LOW",
339
- "audit": audit,
340
- "display_text": "Attention Insuffisante"
341
- }
342
-
343
- # 3. Constraint Fusion (Point 7)
344
- if mask.shape != heatmap.shape:
345
- mask = cv2.resize(mask, (heatmap.shape[1], heatmap.shape[0]))
346
-
347
- final_map = heatmap * mask
348
-
349
- # 4. Reliability (Point 8)
350
- total = np.sum(heatmap) + 1e-8
351
- retained = np.sum(final_map)
352
- reliability = retained / total
353
-
354
- # Point 9: Responsible Display
355
- confidence = "HIGH" if reliability > 0.6 else "LOW"
356
- # FIX: JSON Serialization Error (np.float32 -> float)
357
- audit["reliability_score"] = round(float(reliability), 4)
358
-
359
- # 5. Visualize
360
- img_np = np.array(image)
361
-
362
- # FIX: Ensure img_np is float32 [0,1]
363
- img_np = img_np.astype(np.float32) / 255.0
364
-
365
- # FIX: Resize final_map (Heatmap) to match Original Image Size
366
- # show_cam_on_image requires heatmap and image to be same shape
367
- if final_map.shape != img_np.shape[:2]:
368
- final_map = cv2.resize(final_map, (img_np.shape[1], img_np.shape[0]))
369
-
370
- visualization = show_cam_on_image(img_np, final_map, use_rgb=True)
371
-
372
- return {
373
- "heatmap_array": visualization,
374
- "heatmap_raw": final_map,
375
- # FIX: Cast to float for JSON safety
376
- "reliability_score": round(float(reliability), 2),
377
- "confidence_label": confidence,
378
- "display_text": "Zone d'attention du modèle (Grad-CAM++)"
379
- }
380
-
381
def calculate_cardiothoracic_ratio(self, image: Image.Image) -> Dict[str, Any]:
    """
    Morphology Engine: estimate the cardiothoracic ratio (CTR) of an image.

    Steps:
      1. Segment the heart and then the thorax with ``generate_expert_mask``.
      2. Project each mask onto the horizontal axis and measure the distance
         between its left-most and right-most columns above 0.5.
      3. Return heart width divided by thorax width.

    Returns:
        On success: ``ctr`` (rounded to 2 decimals), ``heart_width_px``,
        ``lung_width_px``, ``valid=True`` and ``reason="Success"``.
        On any failure (including an unexpected exception):
        ``{"ctr": 0.0, "valid": False, "reason": <text>}``.
    """

    def failure(reason):
        # Single place that defines the shape of every failure result.
        return {"ctr": 0.0, "valid": False, "reason": reason}

    def horizontal_extent(mask):
        # Columns in which at least one pixel exceeds 0.5; None if there are none.
        columns = np.flatnonzero(np.max(mask, axis=0) > 0.5)
        if columns.size == 0:
            return None
        # flatnonzero returns ascending indices, so last - first == max - min.
        return int(columns[-1] - columns[0])

    try:
        # 1. Heart segmentation
        heart_config = ExpertSegConfig(
            modality="CXR",
            target_organ="Heart",
            anatomical_prompts=["heart silhouette", "cardiac shadow", "mediastinum"],
            threshold_percentile=85,  # Heart is salient
            min_area_ratio=0.05,
            max_area_ratio=0.40,
            morphology_kernel=5,
        )
        heart_mask = self.generate_expert_mask(image, heart_config)["mask"]
        if heart_mask is None:
            return failure("Heart segmentation failed")

        # 2. Lung / thorax segmentation (only attempted once the heart is found)
        lung_config = ExpertSegConfig(
            modality="CXR",
            target_organ="Thorax",
            anatomical_prompts=["lung fields", "thoracic cage", "rib cage", "diaphragm"],
            threshold_percentile=75,
            min_area_ratio=0.20,
            max_area_ratio=0.85,
            morphology_kernel=5,
        )
        lung_mask = self.generate_expert_mask(image, lung_config)["mask"]
        if lung_mask is None:
            return failure("Lung segmentation failed")

        # 3. Widths
        heart_width = horizontal_extent(heart_mask)
        if heart_width is None:
            return failure("Empty heart mask")

        lung_width = horizontal_extent(lung_mask)
        if lung_width is None:
            return failure("Empty lung mask")

        # 4. Ratio (a single-column thorax mask has extent 0)
        if lung_width == 0:
            return failure("Zero lung width")

        ctr = heart_width / lung_width
        logger.info(f"📐 Morphology Engine: Heart={heart_width}px, Lungs={lung_width}px, CTR={ctr:.2f}")

        return {
            "ctr": round(float(ctr), 2),
            "heart_width_px": int(heart_width),
            "lung_width_px": int(lung_width),
            "valid": True,
            "reason": "Success",
        }

    except Exception as e:
        # Best-effort API: callers get a failure record instead of an exception.
        logger.error(f"CTR Calculation Failed: {e}")
        return failure(str(e))
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+ import logging
8
+ from typing import List, Dict, Any, Optional, Tuple, Union
9
+ from pytorch_grad_cam import GradCAMPlusPlus
10
+ from pytorch_grad_cam.utils.image import show_cam_on_image
11
+ from dataclasses import dataclass
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # =========================================================================
16
+ # CONFIGURATION & EXPERT KNOWLEDGE
17
+ # =========================================================================
18
+
19
@dataclass
class ExpertSegConfig:
    """Segmentation settings for one anatomical target.

    The field order defines the positional constructor signature.
    """

    # Imaging modality label, e.g. "CXR/CT" or "X-Ray".
    modality: str
    # Human-readable name of the structure to segment.
    target_organ: str
    # Text prompts used to build the segmentation mask.
    anatomical_prompts: List[str]
    # Percentile of the similarity map used as the binarisation cut-off
    # (e.g. 75 keeps the top 25% of activations).
    threshold_percentile: int
    # Accepted range for the mask area, as a fraction of the whole image.
    min_area_ratio: float
    max_area_ratio: float
    # Side length of the square kernel used for morphological cleaning.
    morphology_kernel: int
28
+
29
# Expert knowledge base: one segmentation preset per anatomical family.
# "Default" is the fallback used when no specific family matches.
EXPERT_KNOWLEDGE = dict(
    Thoracic=ExpertSegConfig(
        modality="CXR/CT",
        target_organ="Lung Parenchyma",
        anatomical_prompts=[
            "lung parenchyma",
            "bilateral lungs",
            "pulmonary fields",
            "chest x-ray lungs excluding heart",
        ],
        threshold_percentile=75,  # keep the top 25%
        min_area_ratio=0.15,
        max_area_ratio=0.60,
        morphology_kernel=7,
    ),
    Orthopedics=ExpertSegConfig(
        modality="X-Ray",
        target_organ="Bone Structure",
        anatomical_prompts=[
            "bone structure",
            "knee joint",
            "cortical bone",
            "skeletal anatomy",
        ],
        threshold_percentile=85,  # keep the top 15%
        min_area_ratio=0.05,
        max_area_ratio=0.50,
        morphology_kernel=5,
    ),
    Default=ExpertSegConfig(
        modality="General",
        target_organ="Body Part",
        anatomical_prompts=["medical image body part"],
        threshold_percentile=80,
        min_area_ratio=0.05,
        max_area_ratio=0.90,
        morphology_kernel=5,
    ),
)
69
+
70
+ # =========================================================================
71
+ # WRAPPERS AND UTILS
72
+ # =========================================================================
73
+
74
class HuggingFaceWeirdCLIPWrapper(nn.Module):
    """
    Image-only view of a joint image/text model, for use with Grad-CAM.

    The text inputs are fixed at construction time, so ``forward`` only needs
    pixel values and returns the model's ``logits_per_image``.
    """

    def __init__(self, model, text_input_ids, attention_mask):
        super().__init__()
        self.model = model
        self.text_input_ids = text_input_ids
        self.attention_mask = attention_mask

    def forward(self, pixel_values):
        # logits_per_image is (Batch, Num_Prompts): one scaled similarity score
        # per prompt, which Grad-CAM++ differentiates like a class logit.
        return self.model(
            pixel_values=pixel_values,
            input_ids=self.text_input_ids,
            attention_mask=self.attention_mask,
        ).logits_per_image
95
+
96
def reshape_transform(tensor, width=32, height=32):
    """Reshape transformer tokens (B, T, D) into a feature map (B, D, S, S).

    Args:
        tensor: Activations of shape (Batch, Num_Tokens, Dim).
        width: Unused; kept so existing callers keep working.
        height: Unused; kept so existing callers keep working.

    Returns:
        A channels-first view of shape (Batch, Dim, S, S), with S * S tokens.

    Raises:
        ValueError: If the tokens cannot be arranged as a square grid, even
            after dropping one leading token.
    """
    batch, num_tokens, dim = tensor.size(0), tensor.size(1), tensor.size(2)
    side = int(np.sqrt(num_tokens))

    if side * side != num_tokens:
        # One extra token on top of a square grid is assumed to be a leading
        # CLS-style token (the usual ViT layout) and is dropped.
        # NOTE(review): confirm the token position if a non-SigLIP backbone is used.
        inner_side = int(np.sqrt(num_tokens - 1))
        if inner_side * inner_side != num_tokens - 1:
            raise ValueError(
                f"Cannot reshape {num_tokens} tokens into a square grid"
            )
        tensor = tensor[:, 1:, :]
        side = inner_side

    grid = tensor.reshape(batch, side, side, dim)
    # Channels first: (B, S, S, D) -> (B, D, S, S)
    return grid.permute(0, 3, 1, 2)
106
+
107
+ # =========================================================================
108
+ # EXPERT+ EXPLAINABILITY ENGINE
109
+ # =========================================================================
110
+
111
class ExplainabilityEngine:
    """Anatomically constrained Grad-CAM++ explanations.

    ``explain`` builds a text-prompted anatomical mask, computes a Grad-CAM++
    attention map for the target text, and keeps only the attention that falls
    inside the mask.
    """

    def __init__(self, model_wrapper):
        """Bind to a wrapper object exposing ``.model`` and ``.processor``."""
        self.wrapper = model_wrapper
        self.model = model_wrapper.model
        self.processor = model_wrapper.processor
        self.device = self.model.device

    def _get_expert_config(self, anatomical_context: str) -> ExpertSegConfig:
        """Pick the segmentation preset matching ``anatomical_context``.

        Unknown contexts get a copy of the "Default" preset whose prompts are
        replaced by the context itself. The shared preset is never modified,
        so one request cannot change the prompts seen by later requests.
        """
        from dataclasses import replace

        context = anatomical_context.lower()
        if "lung" in context:
            return EXPERT_KNOWLEDGE["Thoracic"]
        if "bone" in context or "knee" in context:
            return EXPERT_KNOWLEDGE["Orthopedics"]
        return replace(
            EXPERT_KNOWLEDGE["Default"],
            anatomical_prompts=[anatomical_context],
        )

    def generate_expert_mask(self, image: Image.Image, config: ExpertSegConfig) -> Dict[str, Any]:
        """
        Expert Segmentation:
        Multi-Prompt Ensembling -> Patch Similarity -> Adaptive Threshold -> Morphology -> Validation.

        Returns:
            ``{"mask": <float array in [0, 1] of shape (H, W)> or None, "audit": dict}``.
            ``mask`` is None when validation fails or an exception occurs;
            the reason is recorded in ``audit``.
        """
        audit = {
            "seg_prompts": config.anatomical_prompts,
            "seg_status": "INIT",
        }
        try:
            w, h = image.size
            inputs = self.processor(
                text=config.anatomical_prompts,
                images=image,
                padding="max_length",
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                # Vision features: (1, Tokens, Dim)
                vision_outputs = self.model.vision_model(
                    pixel_values=inputs["pixel_values"],
                    output_hidden_states=True,
                )
                last_hidden_state = vision_outputs.last_hidden_state

                # Text features: (Prompts, Dim). Some processors do not return
                # an attention mask, so fall back to attending to every token.
                text_input_ids = inputs["input_ids"]
                text_attention_mask = inputs.get("attention_mask")
                if text_attention_mask is None:
                    text_attention_mask = torch.ones_like(text_input_ids)

                text_outputs = self.model.text_model(
                    input_ids=text_input_ids,
                    attention_mask=text_attention_mask,
                )
                text_embeds = text_outputs.pooler_output
                text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

                # Similarity: (1, T, D) @ (D, P) -> (1, T, P), averaged over prompts.
                sim_map = torch.matmul(last_hidden_state, text_embeds.t()).mean(dim=2)

                # Token grid -> image resolution.
                num_tokens = sim_map.size(1)
                side = int(np.sqrt(num_tokens))
                sim_grid = sim_map.reshape(1, side, side)
                sim_grid = torch.nn.functional.interpolate(
                    sim_grid.unsqueeze(0),
                    size=(h, w),
                    mode="bilinear",
                    align_corners=False,
                )
                # Index rather than squeeze() so 1-pixel-wide images stay 2-D.
                sim_grid = sim_grid[0, 0].cpu().numpy()

            # Adaptive thresholding (percentile)
            thresh = np.percentile(sim_grid, config.threshold_percentile)
            binary_mask = (sim_grid > thresh).astype(np.float32)
            audit["seg_threshold"] = float(thresh)

            # Morphological cleaning
            kernel = np.ones((config.morphology_kernel, config.morphology_kernel), np.uint8)
            binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)   # remove noise
            binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)  # fill holes
            binary_mask = cv2.GaussianBlur(binary_mask, (15, 15), 0)              # smooth contours
            binary_mask = (binary_mask - binary_mask.min()) / (
                binary_mask.max() - binary_mask.min() + 1e-8
            )

            # Validation
            val = self._validate_mask(binary_mask, config)
            audit["seg_validation"] = val

            if not val["valid"]:
                logger.warning("Mask Invalid: %s", val["reason"])
                return {"mask": None, "audit": audit}

            return {"mask": binary_mask, "audit": audit}

        except Exception as e:
            logger.error("Segmentation Failed: %s", e)
            audit["seg_error"] = str(e)
            return {"mask": None, "audit": audit}

    def _validate_mask(self, mask: np.ndarray, config: ExpertSegConfig) -> Dict[str, Any]:
        """Check that the share of pixels above 0.5 lies within the configured bounds.

        Returns ``{"valid": True}`` or ``{"valid": False, "reason": <text>}``.
        """
        area_ratio = np.sum(mask > 0.5) / mask.size

        if area_ratio < config.min_area_ratio:
            return {"valid": False, "reason": f"Small Area: {area_ratio:.2f} < {config.min_area_ratio}"}
        if area_ratio > config.max_area_ratio:
            return {"valid": False, "reason": f"Large Area: {area_ratio:.2f} > {config.max_area_ratio}"}

        # No explicit connectivity check: the area bounds plus the
        # morphological opening are relied on to reject scattered noise.
        return {"valid": True}

    def generate_expert_gradcam(self, image: Image.Image, target_prompts: List[str]) -> Dict[str, Any]:
        """
        Expert Grad-CAM:
        1. Multi-Prompt Ensembling (Averaging heatmaps).
        2. Layer Selection: Encoder Layer -2.
        3. Target: Cosine Score.

        Returns:
            ``{"map": <2-D float array> or None, "audit": dict}``. ``map`` is
            None when any step raises; the error text is stored in ``audit``.
        """
        # Imported once per call rather than once per prompt.
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

        audit = {"gradcam_prompts": target_prompts, "gradcam_status": "INIT"}

        try:
            inputs = self.processor(
                text=target_prompts,
                images=image,
                padding="max_length",
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Some processors do not return an attention mask.
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask")
            if attention_mask is None and input_ids is not None:
                attention_mask = torch.ones_like(input_ids)

            model_wrapper_cam = HuggingFaceWeirdCLIPWrapper(self.model, input_ids, attention_mask)

            # Second-to-last encoder layer of model.vision_model.encoder.layers.
            target_layers = [self.model.vision_model.encoder.layers[-2].layer_norm1]

            cam = GradCAMPlusPlus(
                model=model_wrapper_cam,
                target_layers=target_layers,
                reshape_transform=reshape_transform,  # handles (B, T, D)
            )

            pixel_values = inputs.get("pixel_values")

            # One Grad-CAM pass per prompt; prompt i is "class" i in the logits.
            # The CAM output is (Batch, H, W) and Batch is assumed to be 1.
            maps = []
            for i in range(len(target_prompts)):
                grayscale_cam = cam(
                    input_tensor=pixel_values,
                    targets=[ClassifierOutputTarget(i)],
                )
                maps.append(grayscale_cam[0, :])

            avg_cam = np.mean(np.array(maps), axis=0)

            # Normalise to [0, 1] (skipped for a constant map).
            if avg_cam.max() > avg_cam.min():
                avg_cam = (avg_cam - avg_cam.min()) / (avg_cam.max() - avg_cam.min())

            # Suppress everything below the 80th percentile, keep the rest continuous.
            cam_threshold = np.percentile(avg_cam, 80)
            avg_cam[avg_cam < cam_threshold] = 0.0

            # Stretch the surviving values back to 0-1 for visibility.
            if avg_cam.max() > 0:
                avg_cam = avg_cam / avg_cam.max()

            # Smooth the jagged edges left by thresholding.
            avg_cam = cv2.GaussianBlur(avg_cam, (11, 11), 0)

            audit["gradcam_threshold_val"] = float(cam_threshold)

            return {"map": avg_cam, "audit": audit}

        except Exception as e:
            logger.error("Grad-CAM Failed: %s", e)
            audit["gradcam_error"] = str(e)
            return {"map": None, "audit": audit}

    def explain(self, image: Image.Image, target_text: str, anatomical_context: str) -> Dict[str, Any]:
        """
        Final Expert Fusion Pipeline.

        Returns a dict with ``heatmap_array`` (overlay image), ``heatmap_raw``
        (masked attention map), ``reliability_score`` (share of attention that
        lies inside the anatomical mask), ``confidence_label``, ``audit`` and
        ``display_text``. Both heatmaps are None when segmentation or
        Grad-CAM fails.
        """
        config = self._get_expert_config(anatomical_context)

        # 1. Anatomical mask (strict constraint)
        seg_res = self.generate_expert_mask(image, config)
        mask = seg_res["mask"]
        audit = seg_res["audit"]

        if mask is None:
            # Strict safety: no explanation if segmentation fails.
            return {
                "heatmap_array": None,
                "heatmap_raw": None,
                "reliability_score": 0.0,
                "confidence_label": "UNSAFE",
                "audit": audit,
                "display_text": "Validation Anatomique Échouée",
            }

        # 2. Attention map (single prompt: the target text is used as given)
        gradcam_res = self.generate_expert_gradcam(image, [target_text])
        heatmap = gradcam_res["map"]
        audit.update(gradcam_res["audit"])

        if heatmap is None:
            return {
                "heatmap_array": None,
                "heatmap_raw": None,
                "reliability_score": 0.0,
                "confidence_label": "LOW",
                "audit": audit,
                "display_text": "Attention Insuffisante",
            }

        # 3. Constraint fusion
        if mask.shape != heatmap.shape:
            mask = cv2.resize(mask, (heatmap.shape[1], heatmap.shape[0]))

        final_map = heatmap * mask

        # 4. Reliability: fraction of the attention mass retained by the mask.
        total = np.sum(heatmap) + 1e-8
        retained = np.sum(final_map)
        # Cast to a built-in float: NumPy float32 scalars are not JSON serialisable.
        reliability = float(retained / total)

        confidence = "HIGH" if reliability > 0.6 else "LOW"
        audit["reliability_score"] = round(reliability, 4)

        # 5. Visualise. show_cam_on_image needs an (H, W, 3) float image in
        # [0, 1], so grayscale / RGBA inputs are converted to RGB first.
        img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.0

        # The heatmap must match the original image size.
        if final_map.shape != img_np.shape[:2]:
            final_map = cv2.resize(final_map, (img_np.shape[1], img_np.shape[0]))

        visualization = show_cam_on_image(img_np, final_map, use_rgb=True)

        return {
            "heatmap_array": visualization,
            "heatmap_raw": final_map,
            "reliability_score": round(reliability, 2),
            "confidence_label": confidence,
            "audit": audit,
            "display_text": "Zone d'attention du modèle (Grad-CAM++)",
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
localization.py CHANGED
@@ -1,83 +1,83 @@
1
# Localization mappings (English -> French).
# This module translates the AI results without touching the original prompts,
# which must stay in English for model performance.

DOMAIN_TRANSLATIONS = {
    "Thoracic": {
        "label": "Thoracique",
        "description": "Analyse Radiographique du Thorax",
    },
    "Dermatology": {
        "label": "Dermatologie",
        "description": "Analyse Dermatoscope des Lésions Cutanées",
    },
    "Histology": {
        "label": "Histologie",
        "description": "Analyse Microscopique (H&E)",
    },
    "Ophthalmology": {
        "label": "Ophtalmologie",
        "description": "Fond d'Oeil (Rétine)",
    },
    "Orthopedics": {
        "label": "Orthopédie",
        "description": "Radiographie Osseuse",
    },
}

LABEL_TRANSLATIONS = {
    # --- THORACIC ---
    "Diffuse interstitial opacities or ground-glass pattern (Viral/Atypical Pneumonia)":
        "Opacités interstitielles diffuses ou aspect en verre dépoli (Pneumonie Virale/Atypique)",
    "Focal alveolar consolidation with air bronchograms (Bacterial Pneumonia)":
        "Condensation alvéolaire focale avec bronchogrammes aériens (Pneumonie Bactérienne)",
    "Perfectly clear lungs, sharp costophrenic angles, no pathology":
        "Poumons parfaitement clairs, angles costophréniques nets, aucune pathologie",
    "Pneumothorax (Lung collapse)":
        "Pneumothorax (Décollement de la plèvre)",
    "Pleural Effusion (Fluid)":
        "Épanchement Pleural (Liquide)",
    "Cardiomegaly (Enlarged heart)":
        "Cardiomégalie (Cœur élargi)",
    "Pulmonary Edema":
        "Œdème Pulmonaire",
    "Lung Nodule or Mass":
        "Nodule ou Masse Pulmonaire",
    "Atelectasis (Lung collapse)":
        "Atélectasie (Affaissement pulmonaire)",

    # --- DERMATOLOGY ---
    "A healthy skin area without lesion":
        "Zone de peau saine sans lésion",
    "A benign nevus (mole) regular, symmetrical and homogeneous":
        "Nævus bénin (grain de beauté) régulier, symétrique et homogène",
    "A seborrheic keratosis (benign warty lesion)":
        "Kératose séborrhéique (lésion verruqueuse bénigne)",
    "A malignant melanoma with asymmetry, irregular borders and multiple colors":
        "Mélanome malin (Asymétrie, Bords irréguliers, Couleurs multiples)",
    "A basal cell carcinoma (pearly or ulcerated lesion)":
        "Carcinome basocellulaire (lésion perlée ou ulcérée)",
    "A squamous cell carcinoma (crusty or budding lesion)":
        "Carcinome épidermoïde (lésion croûteuse ou bourgeonnante)",
    "A non-specific inflammatory skin lesion":
        "Lésion cutanée inflammatoire non spécifique",

    # --- ORTHOPEDICS ---
    "Severe osteoarthritis with bone-on-bone contact and large osteophytes (Grade 4)":
        "Arthrose sévère avec contact os-contre-os et ostéophytes importants (Grade 4)",
    "Moderate osteoarthritis with definite joint space narrowing (Grade 2-3)":
        "Arthrose modérée avec pincement articulaire net (Grade 2-3)",
    "Normal knee joint with preserved joint space and no osteophytes (Grade 0-1)":
        "Genou normal, interligne articulaire préservé (Grade 0-1)",
    "Total knee arthroplasty (TKA) with metallic implant":
        "Prothèse totale de genou (implant métallique)",
    "Acute knee fracture or dislocation":
        "Fracture ou luxation aiguë du genou",
    "Other x-ray view (Chest, Hand, Foot, Pediatric) - OUT OF DISTRIBUTION":
        "Autre vue radiographique (Hors périmètre)",
    "A knee x-ray view (Knee Joint)":
        "Radiographie du Genou",
}
64
-
65
def localize_result(result_json):
    """
    Translate raw (English) results into French using the mapping tables.

    ``result_json`` is modified in place and also returned. Labels without a
    known translation are left in English.
    """
    # 1. Localize the domain
    domain = result_json['domain']
    domain_entry = DOMAIN_TRANSLATIONS.get(domain['label'])
    if domain_entry is not None:
        domain['label'] = domain_entry['label']
        domain['description'] = domain_entry['description']

    # 2. Localize the specific findings
    for finding in result_json['specific']:
        translated = LABEL_TRANSLATIONS.get(finding['label'])
        if translated is not None:
            finding['label'] = translated
        # No translation found: keep the English label (fallback).

    return result_json
 
1
# Localization mappings (English -> French).
# This module translates the AI results without touching the original prompts,
# which must stay in English for model performance.

DOMAIN_TRANSLATIONS = {
    "Thoracic": {
        "label": "Thoracique",
        "description": "Analyse Radiographique du Thorax",
    },
    "Dermatology": {
        "label": "Dermatologie",
        "description": "Analyse Dermatoscope des Lésions Cutanées",
    },
    "Histology": {
        "label": "Histologie",
        "description": "Analyse Microscopique (H&E)",
    },
    "Ophthalmology": {
        "label": "Ophtalmologie",
        "description": "Fond d'Oeil (Rétine)",
    },
    "Orthopedics": {
        "label": "Orthopédie",
        "description": "Radiographie Osseuse",
    },
}

LABEL_TRANSLATIONS = {
    # --- THORACIC ---
    "Diffuse interstitial opacities or ground-glass pattern (Viral/Atypical Pneumonia)":
        "Opacités interstitielles diffuses ou aspect en verre dépoli (Pneumonie Virale/Atypique)",
    "Focal alveolar consolidation with air bronchograms (Bacterial Pneumonia)":
        "Condensation alvéolaire focale avec bronchogrammes aériens (Pneumonie Bactérienne)",
    "Perfectly clear lungs, sharp costophrenic angles, no pathology":
        "Poumons parfaitement clairs, angles costophréniques nets, aucune pathologie",
    "Pneumothorax (Lung collapse)":
        "Pneumothorax (Décollement de la plèvre)",
    "Pleural Effusion (Fluid)":
        "Épanchement Pleural (Liquide)",
    "Cardiomegaly (Enlarged heart)":
        "Cardiomégalie (Cœur élargi)",
    "Pulmonary Edema":
        "Œdème Pulmonaire",
    "Lung Nodule or Mass":
        "Nodule ou Masse Pulmonaire",
    "Atelectasis (Lung collapse)":
        "Atélectasie (Affaissement pulmonaire)",

    # --- DERMATOLOGY ---
    "A healthy skin area without lesion":
        "Zone de peau saine sans lésion",
    "A benign nevus (mole) regular, symmetrical and homogeneous":
        "Nævus bénin (grain de beauté) régulier, symétrique et homogène",
    "A seborrheic keratosis (benign warty lesion)":
        "Kératose séborrhéique (lésion verruqueuse bénigne)",
    "A malignant melanoma with asymmetry, irregular borders and multiple colors":
        "Mélanome malin (Asymétrie, Bords irréguliers, Couleurs multiples)",
    "A basal cell carcinoma (pearly or ulcerated lesion)":
        "Carcinome basocellulaire (lésion perlée ou ulcérée)",
    "A squamous cell carcinoma (crusty or budding lesion)":
        "Carcinome épidermoïde (lésion croûteuse ou bourgeonnante)",
    "A non-specific inflammatory skin lesion":
        "Lésion cutanée inflammatoire non spécifique",

    # --- ORTHOPEDICS ---
    "Severe osteoarthritis with bone-on-bone contact and large osteophytes (Grade 4)":
        "Arthrose sévère avec contact os-contre-os et ostéophytes importants (Grade 4)",
    "Moderate osteoarthritis with definite joint space narrowing (Grade 2-3)":
        "Arthrose modérée avec pincement articulaire net (Grade 2-3)",
    "Normal knee joint with preserved joint space and no osteophytes (Grade 0-1)":
        "Genou normal, interligne articulaire préservé (Grade 0-1)",
    "Total knee arthroplasty (TKA) with metallic implant":
        "Prothèse totale de genou (implant métallique)",
    "Acute knee fracture or dislocation":
        "Fracture ou luxation aiguë du genou",
    "Other x-ray view (Chest, Hand, Foot, Pediatric) - OUT OF DISTRIBUTION":
        "Autre vue radiographique (Hors périmètre)",
    "A knee x-ray view (Knee Joint)":
        "Radiographie du Genou",
}
64
+
65
def localize_result(result_json):
    """
    Translate raw (English) results into French using the mapping tables.

    ``result_json`` is modified in place and also returned. Labels without a
    known translation are left in English.
    """
    # 1. Localize the domain
    domain = result_json['domain']
    domain_entry = DOMAIN_TRANSLATIONS.get(domain['label'])
    if domain_entry is not None:
        domain['label'] = domain_entry['label']
        domain['description'] = domain_entry['description']

    # 2. Localize the specific findings
    for finding in result_json['specific']:
        translated = LABEL_TRANSLATIONS.get(finding['label'])
        if translated is not None:
            finding['label'] = translated
        # No translation found: keep the English label (fallback).

    return result_json
main.py CHANGED
The diff for this file is too large to render. See raw diff
 
medical_labels.py DELETED
@@ -1,307 +0,0 @@
1
- from typing import Dict, List, Any
2
-
3
- # =========================================================================
4
- # CANONICAL MEDICAL DOMAINS CONFIGURATION (MODEL SOURCE OF TRUTH)
5
- # =========================================================================
6
- # - Prompts must be in ENGLISH (Model Language).
7
- # - Labels must have a stable 'id'.
8
- # - Logic Gates define structural/quality constraints.
9
-
10
- MEDICAL_DOMAINS = {
11
- 'Thoracic': {
12
- 'id': 'DOM_THORACIC',
13
- 'domain_prompt': 'Chest X-Ray Analysis',
14
- 'specific_labels': [
15
- {'id': 'TH_PNEUMONIA_VIRAL', 'label_en': 'Diffuse interstitial opacities or ground-glass pattern (Viral/Atypical Pneumonia)'},
16
- {'id': 'TH_PNEUMONIA_BACT', 'label_en': 'Focal alveolar consolidation with air bronchograms (Bacterial Pneumonia)'},
17
- {'id': 'TH_NORMAL', 'label_en': 'Normal chest radiograph: normal cardiothoracic ratio, clear lungs, no pleural abnormality'},
18
- {'id': 'TH_PNEUMOTHORAX', 'label_en': 'Pneumothorax (Lung collapse)'},
19
- {'id': 'TH_PLEURAL_EFFUSION', 'label_en': 'Pleural Effusion (Fluid)'},
20
- {'id': 'TH_CARDIOMEGALY_CLEAR', 'label_en': 'Cardiomegaly with clear lung fields (no pulmonary edema)'},
21
- {'id': 'TH_CARDIOMEGALY_EDEMA', 'label_en': 'Cardiomegaly with pulmonary congestion or edema'},
22
- {'id': 'TH_EDEMA', 'label_en': 'Pulmonary Edema (without cardiomegaly)'},
23
- {'id': 'TH_NODULE', 'label_en': 'Lung Nodule or Mass'},
24
- {'id': 'TH_ATELECTASIS', 'label_en': 'Atelectasis (Lung collapse)'}
25
- ],
26
- 'logic_gate': {
27
- 'prompt': 'Evaluate cardiac silhouette size',
28
- 'labels': ['Normal cardiac size (CTR < 0.5)', 'Enlarged cardiac silhouette (Cardiomegaly)'],
29
- 'penalty_target': 'TH_NORMAL', # Penalize the ID of the normal label
30
- 'abnormal_index': 1
31
- }
32
- },
33
- 'Dermatology': {
34
- 'id': 'DOM_DERMATOLOGY',
35
- 'domain_prompt': 'Dermatoscopic analysis of a pigmented or non-pigmented skin lesion',
36
- 'specific_labels': [
37
- {'id': 'DERM_NORMAL', 'label_en': 'Normal skin without visible lesion or abnormal pigmentation'},
38
- {'id': 'DERM_NEVUS', 'label_en': 'Benign melanocytic nevus with symmetry and uniform pigmentation'},
39
- {'id': 'DERM_SEBORRHEIC', 'label_en': 'Seborrheic keratosis (benign warty lesion)'},
40
- {'id': 'DERM_MELANOMA', 'label_en': 'Malignant melanoma with asymmetry, irregular borders, and color variegation'},
41
- {'id': 'DERM_BCC', 'label_en': 'Basal cell carcinoma (pearly or ulcerated lesion)'},
42
- {'id': 'DERM_SCC', 'label_en': 'Squamous cell carcinoma (crusty or budding lesion)'},
43
- {'id': 'DERM_INFLAMMATORY', 'label_en': 'Inflammatory skin lesion (Eczema, Psoriasis)'}
44
- ],
45
- 'logic_gate': {
46
- 'prompt': 'Is there a visible skin lesion?',
47
- 'labels': ['No visible skin lesion', 'Visible skin lesion (pigmented or non-pigmented)'],
48
- 'penalty_target': 'ALL_PATHOLOGY',
49
- 'abnormal_index': 0
50
- }
51
- },
52
- 'Histology': {
53
- 'id': 'DOM_HISTOLOGY',
54
- 'domain_prompt': 'Microscopic analysis of a histological section (H&E stain)',
55
- 'specific_labels': [
56
- {'id': 'HIST_HEALTHY_BREAST', 'label_en': 'Healthy breast tissue with preserved lobular architecture'},
57
- {'id': 'HIST_HEALTHY_PROSTATE', 'label_en': 'Healthy prostatic tissue with regular glands'},
58
- {'id': 'HIST_IDC_BREAST', 'label_en': 'Invasive ductal carcinoma (Disorganized cells)'},
59
- {'id': 'HIST_ADENO_PROSTATE', 'label_en': 'Prostate adenocarcinoma (Gland fusion)'},
60
- {'id': 'HIST_DYSPLASIA', 'label_en': 'Cervical dysplasia or intraepithelial neoplasia'},
61
- {'id': 'HIST_COLON_CA', 'label_en': 'Colon cancer tumor tissue'},
62
- {'id': 'HIST_LUNG_CA', 'label_en': 'Lung cancer tumor tissue'},
63
- {'id': 'HIST_ADIPOSE', 'label_en': 'Adipose tissue (Fat) or connective stroma'},
64
- {'id': 'HIST_ARTIFACT', 'label_en': 'Preparation artifact, empty area, or blurred region'}
65
- ],
66
- 'logic_gate': {
67
- 'prompt': 'Assess histological validity of the image',
68
- 'labels': ['Adequate H&E tissue section', 'Artifact, empty area, or blurred region'],
69
- 'penalty_target': 'ALL_DIAGNOSIS',
70
- 'abnormal_index': 1
71
- }
72
- },
73
- 'Ophthalmology': {
74
- 'id': 'DOM_OPHTHALMOLOGY',
75
- 'domain_prompt': 'Fundus photography (Retina)',
76
- 'specific_labels': [
77
- {'id': 'OPH_NORMAL', 'label_en': 'Normal retina with visible optic disc and macula'},
78
- {'id': 'OPH_DIABETIC', 'label_en': 'Diabetic retinopathy (hemorrhages, exudates)'},
79
- {'id': 'OPH_GLAUCOMA', 'label_en': 'Glaucoma (optic disc cupping)'},
80
- {'id': 'OPH_AMD', 'label_en': 'Macular degeneration (drusen or atrophy)'}
81
- ],
82
- 'logic_gate': {
83
- 'prompt': 'Is the fundus image clinically interpretable?',
84
- 'labels': ['Good quality fundus image', 'Poor quality, uninterpretable or partial view'],
85
- 'penalty_target': 'ALL_DIAGNOSIS',
86
- 'abnormal_index': 1
87
- }
88
- },
89
- 'Orthopedics': {
90
- 'id': 'DOM_ORTHOPEDICS',
91
- 'domain_prompt': 'Bone X-Ray (Musculoskeletal)',
92
- 'stage_1_triage': {
93
- 'prompt': 'Anatomical region identification',
94
- 'labels': [
95
- 'Other x-ray view (Chest, Hand, Foot, Pediatric) - OUT OF DISTRIBUTION',
96
- 'A knee x-ray view (Knee Joint)'
97
- ]
98
- },
99
- 'specific_labels': [
100
- {'id': 'ORTH_OA_SEVERE', 'label_en': 'Severe osteoarthritis (Grade 4)'},
101
- {'id': 'ORTH_OA_MODERATE', 'label_en': 'Moderate osteoarthritis (Grade 2-3)'},
102
- {'id': 'ORTH_NORMAL', 'label_en': 'Normal knee'},
103
- {'id': 'ORTH_IMPLANT', 'label_en': 'Implant'}
104
- ],
105
- 'stage_2_diagnosis': {
106
- 'prompt': 'Knee Osteoarthritis Severity Assessment',
107
- 'labels': [
108
- {'id': 'ORTH_OA_SEVERE', 'label_en': 'Severe osteoarthritis with bone-on-bone contact (Grade 4)'},
109
- {'id': 'ORTH_OA_MODERATE', 'label_en': 'Moderate osteoarthritis with definite joint space narrowing (Grade 2-3)'},
110
- {'id': 'ORTH_NORMAL', 'label_en': 'Normal knee joint with preserved joint space (Grade 0-1)'},
111
- {'id': 'ORTH_IMPLANT', 'label_en': 'Total knee arthroplasty (TKA) with metallic implant'},
112
- {'id': 'ORTH_FRACTURE', 'label_en': 'Acute knee fracture or dislocation'}
113
- ]
114
- },
115
- 'logic_gate': {
116
- 'prompt': 'Is there a metallic implant?',
117
- 'labels': ['Native knee joint', 'Knee with metallic implant (Arthroplasty)'],
118
- 'penalty_target': 'ORTH_OA', # Logic target string match (Prefix)
119
- 'abnormal_index': 1
120
- }
121
- }
122
- }
123
-
124
- # =========================================================================
125
- # FRENCH TRANSLATIONS (USER INTERFACE ONLY)
126
- # =========================================================================
127
- # - Strict Mapping: ID -> {title, description}
128
- # - No dynamic translation allowed.
129
-
130
- LABEL_TRANSLATIONS_FR = {
131
- # --- THORACIC ---
132
- 'TH_NORMAL': {
133
- 'short': 'Thorax sans anomalie',
134
- 'long': 'Silhouette cardiaque normale, poumons clairs, pas d’épanchement.',
135
- 'severity': 'low'
136
- },
137
- 'TH_PNEUMONIA_VIRAL': {
138
- 'short': 'Pneumonie Virale / Atypique',
139
- 'long': 'Opacités interstitielles diffuses ou verre dépoli.',
140
- 'severity': 'high'
141
- },
142
- 'TH_PNEUMONIA_BACT': {
143
- 'short': 'Pneumonie Bactérienne',
144
- 'long': 'Consolidation alvéolaire focale avec bronchogramme aérien.',
145
- 'severity': 'high'
146
- },
147
- 'TH_PNEUMOTHORAX': {
148
- 'short': 'Pneumothorax',
149
- 'long': 'Présence possible d’air dans la cavité pleurale (collapsus).',
150
- 'severity': 'emergency'
151
- },
152
- 'TH_PLEURAL_EFFUSION': {
153
- 'short': 'Épanchement Pleural',
154
- 'long': 'Accumulation de liquide dans l’espace pleural.',
155
- 'severity': 'medium'
156
- },
157
- 'TH_CARDIOMEGALY_CLEAR': { # UPDATED ID
158
- 'short': 'Cardiomégalie (Poumons clairs)',
159
- 'long': 'Silhouette cardiaque augmentée de taille sans signe d’œdème pulmonaire.',
160
- 'severity': 'medium'
161
- },
162
- 'TH_CARDIOMEGALY_EDEMA': {
163
- 'short': 'Cardiomégalie avec Stase',
164
- 'long': 'Cœur augmenté de taille associé à une congestion pulmonaire.',
165
- 'severity': 'high'
166
- },
167
- 'TH_EDEMA': {
168
- 'short': 'Œdème Pulmonaire',
169
- 'long': 'Surcharge liquidienne pulmonaire (sans cardiomégalie évidente).',
170
- 'severity': 'high'
171
- },
172
- 'TH_NODULE': {
173
- 'short': 'Nodule ou Masse Pulmonaire',
174
- 'long': 'Lésion focale suspecte nécessitant un scanner de contrôle.',
175
- 'severity': 'high'
176
- },
177
- 'TH_ATELECTASIS': {
178
- 'short': 'Atélectasie',
179
- 'long': 'Affaissement d’une partie du poumon.',
180
- 'severity': 'medium'
181
- },
182
-
183
- # --- DERMATOLOGY ---
184
- 'DERM_NORMAL': {
185
- 'short': 'Peau saine / Pas de lésion',
186
- 'long': 'Aucune lésion dermatologique suspecte visible.',
187
- 'severity': 'low'
188
- },
189
- 'DERM_NEVUS': {
190
- 'short': 'Nævus Bénin (Grain de beauté)',
191
- 'long': 'Lésion régulière, symétrique et homogène.',
192
- 'severity': 'low'
193
- },
194
- 'DERM_SEBORRHEIC': {
195
- 'short': 'Kératose Séborrhéique',
196
- 'long': 'Lésion bénigne fréquente ("verrue de vieillesse").',
197
- 'severity': 'low'
198
- },
199
- 'DERM_MELANOMA': {
200
- 'short': 'Suspicion de Mélanome',
201
- 'long': 'Lésion pigmentée asymétrique, bords irréguliers (critères ABCDE). Urgence.',
202
- 'severity': 'emergency'
203
- },
204
- 'DERM_BCC': {
205
- 'short': 'Carcinome Basocellulaire',
206
- 'long': 'Lésion perlée ou ulcérée suggérant un carcinome non-mélanique.',
207
- 'severity': 'high'
208
- },
209
- 'DERM_SCC': {
210
- 'short': 'Carcinome Épidermoïde',
211
- 'long': 'Lésion croûteuse ou bourgeonnante suspecte.',
212
- 'severity': 'high'
213
- },
214
- 'DERM_INFLAMMATORY': {
215
- 'short': 'Lésion Inflammatoire',
216
- 'long': 'Aspect compatible avec eczéma, psoriasis ou dermatite.',
217
- 'severity': 'medium'
218
- },
219
-
220
- # --- HISTOLOGY ---
221
- 'HIST_ARTIFACT': {
222
- 'short': 'Qualité Insuffisante (Artefact)',
223
- 'long': 'Tissu non interprétable (section vide, floue ou artefact technique).',
224
- 'severity': 'none'
225
- },
226
- 'HIST_HEALTHY_BREAST': {
227
- 'short': 'Tissu Mammaire Sain',
228
- 'long': 'Architecture lobulaire préservée.',
229
- 'severity': 'low'
230
- },
231
- 'HIST_IDC_BREAST': {
232
- 'short': 'Carcinome Canalaire Infiltrant',
233
- 'long': 'Prolifération cellulaire désorganisée invasive (Sein).',
234
- 'severity': 'high'
235
- },
236
- 'HIST_HEALTHY_PROSTATE': {
237
- 'short': 'Tissu Prostatique Sain',
238
- 'long': 'Glandes régulières, stroma normal.',
239
- 'severity': 'low'
240
- },
241
- 'HIST_ADENO_PROSTATE': {
242
- 'short': 'Adénocarcinome Prostatique',
243
- 'long': 'Fusion glandulaire et atypies cytonucléaires.',
244
- 'severity': 'high'
245
- },
246
- 'HIST_COLON_CA': {'short': 'Cancer Colorectal', 'long': 'Tissu tumoral colique.', 'severity': 'high'},
247
- 'HIST_LUNG_CA': {'short': 'Cancer Pulmonaire', 'long': 'Tissu tumoral pulmonaire.', 'severity': 'high'},
248
- 'HIST_DYSPLASIA': {'short': 'Dysplasie / CIN', 'long': 'Anomalies précancéreuses.', 'severity': 'medium'},
249
- 'HIST_ADIPOSE': {'short': 'Tissu Adipeux / Stroma', 'long': 'Tissu de soutien normal.', 'severity': 'low'},
250
-
251
- # --- OPHTHALMOLOGY ---
252
- 'OPH_NORMAL': {
253
- 'short': 'Fond d’œil Normal',
254
- 'long': 'Rétine, macula et papille d’aspect sain.',
255
- 'severity': 'low'
256
- },
257
- 'OPH_DIABETIC': {
258
- 'short': 'Rétinopathie Diabétique',
259
- 'long': 'Présence d’hémorragies, exsudats ou anévrismes.',
260
- 'severity': 'high'
261
- },
262
- 'OPH_GLAUCOMA': {
263
- 'short': 'Suspicion de Glaucome',
264
- 'long': 'Excavation papillaire (cup/disc ratio) augmentée.',
265
- 'severity': 'high'
266
- },
267
- 'OPH_AMD': {
268
- 'short': 'DMLA',
269
- 'long': 'Dégénérescence Maculaire (drusens ou atrophie).',
270
- 'severity': 'medium'
271
- },
272
-
273
- # --- ORTHOPEDICS ---
274
- 'ORTH_NORMAL': {
275
- 'short': 'Genou Normal',
276
- 'long': 'Interligne articulaire préservé, pas d’ostéophyte.',
277
- 'severity': 'low'
278
- },
279
- 'ORTH_OA_MODERATE': {
280
- 'short': 'Arthrose Modérée (Grade 2-3)',
281
- 'long': 'Pincement articulaire visible et ostéophytes.',
282
- 'severity': 'medium'
283
- },
284
- 'ORTH_OA_SEVERE': {
285
- 'short': 'Arthrose Sévère (Grade 4)',
286
- 'long': 'Disparition de l’interligne (os sur os), déformation.',
287
- 'severity': 'high'
288
- },
289
- 'ORTH_IMPLANT': {
290
- 'short': 'Prothèse Totale (PTG)',
291
- 'long': 'Genou avec implant métallique (Arthroplastie).',
292
- 'severity': 'low'
293
- },
294
- 'ORTH_FRACTURE': {
295
- 'short': 'Fracture Récente / Luxation',
296
- 'long': 'Solution de continuité osseuse ou perte de congruence.',
297
- 'severity': 'emergency'
298
- }
299
- }
300
-
301
- DOMAIN_TRANSLATIONS_FR = {
302
- 'Thoracic': 'Radiographie Thoracique',
303
- 'Dermatology': 'Dermatoscopie',
304
- 'Histology': 'Histopathologie (H&E)',
305
- 'Ophthalmology': 'Fond d’Oeil (Rétine)',
306
- 'Orthopedics': 'Radiographie Osseuse'
307
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
quality_control.py DELETED
@@ -1,235 +0,0 @@
1
-
2
- import numpy as np
3
- import cv2
4
- import pydicom
5
- import logging
6
- from typing import Dict, Any, List, Tuple, Union
7
- from PIL import Image
8
-
9
- logger = logging.getLogger("ElephMind-QC")
10
-
11
class QualityControlEngine:
    """
    Image quality gatekeeper.

    Scores an image on six categories, each contributing 0 or 1:
    1. Structural (DICOM)
    2. Intensity (Contrast)
    3. Blur (Laplacian)
    4. Noise (SNR)
    5. Saturation (Clipping)
    6. Spatial (Aspect Ratio)

    Decision:
    QC score = weighted sum of the category scores divided by the total weight.
    score >= pass_threshold (default 0.75) -> PASS
    """

    def __init__(self, pass_threshold: float = 0.75):
        """
        Args:
            pass_threshold: minimum normalised score (0-1) for an image to pass.
        """
        # Relative importance of each category in the final score.
        self.weights = {
            "structure": 0.30,
            "blur": 0.20,
            "contrast": 0.20,
            "noise": 0.10,
            "saturation": 0.10,
            "spatial": 0.10,
        }
        # Hard limits used to turn each raw metric into a 0/1 category score.
        self.thresholds = {
            "blur_var": 100.0,       # Laplacian variance < 100 -> blurry
            "contrast_std": 10.0,    # Std dev < 10 -> low contrast
            "entropy": 4.0,          # Entropy < 4.0 -> low information
            "snr_min": 2.0,          # Signal-to-noise ratio < 2.0 -> noisy
            "saturation_max": 0.05,  # > 5% near-black/near-white pixels -> saturated
            "aspect_min": 0.5,       # Too thin
            "aspect_max": 2.0,       # Too wide
        }
        self.pass_threshold = pass_threshold

    def evaluate_dicom(self, dataset: pydicom.dataset.FileDataset) -> Dict[str, Any]:
        """
        Gate 1: structural DICOM check.

        Verifies that PixelData is present and that Rows/Columns are positive.

        Returns:
            {"passed": bool, "score": 0.0 or 1.0, "reasons": List[str]}
        """
        try:
            # 1. Pixel data presence
            if not hasattr(dataset, "PixelData") or dataset.PixelData is None:
                return {"passed": False, "score": 0.0, "reasons": ["CRITICAL: Missing PixelData"]}

            # 2. Dimensions
            rows = getattr(dataset, "Rows", 0)
            cols = getattr(dataset, "Columns", 0)
            if rows <= 0 or cols <= 0:
                return {"passed": False, "score": 0.0, "reasons": ["CRITICAL: Invalid Dimensions (Rows/Cols <= 0)"]}
        except Exception as e:
            # Any failure while reading the attributes is reported as corruption
            # instead of propagating to the caller.
            return {"passed": False, "score": 0.0, "reasons": [f"CRITICAL: DICOM Corrupt ({str(e)})"]}

        return {"passed": True, "score": 1.0, "reasons": []}

    def compute_metrics(self, image: np.ndarray) -> Dict[str, float]:
        """
        Compute the raw quality metrics for an image of shape (H, W) or (H, W, C).

        The histogram range and saturation limits below are written for a
        0-255 intensity scale.
        NOTE(review): float inputs scaled 0-1 would be reported as fully
        saturated - confirm what callers pass.

        Returns:
            Dict with keys blur_var, std_dev, entropy, snr, saturation_pct,
            aspect_ratio.

        Raises:
            ValueError: if the image has no pixels.
        """
        if image.size == 0:
            raise ValueError("Cannot compute quality metrics on an empty image")

        # Reduce to a single intensity plane.
        if image.ndim == 3:
            if image.shape[2] in (1, 2):
                # (H, W, 1) grayscale or (H, W, 2) gray+alpha: channel 0 is the
                # intensity; RGB2GRAY does not accept these channel counts.
                gray = np.ascontiguousarray(image[:, :, 0])
            else:
                gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image

        metrics = {}

        # 1. Blur (variance of the Laplacian)
        metrics['blur_var'] = float(cv2.Laplacian(gray, cv2.CV_64F).var())

        # 2. Intensity / contrast: standard deviation and histogram entropy
        metrics['std_dev'] = float(np.std(gray))
        hist, _ = np.histogram(gray, bins=256, range=(0, 256))
        prob = hist / (np.sum(hist) + 1e-8)
        prob = prob[prob > 0]  # drop empty bins so log2 is defined
        metrics['entropy'] = float(-np.sum(prob * np.log2(prob)))

        # 3. Noise (simple SNR estimate): signal = mean intensity,
        #    noise = std of the high-pass residual (image - Gaussian blur).
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        noise_img = gray.astype(float) - blurred.astype(float)
        noise_std = np.std(noise_img) + 1e-8  # epsilon avoids division by zero
        metrics['snr'] = float(np.mean(gray) / noise_std)

        # 4. Saturation: fraction of pixels that are near-black (<= 5)
        #    or near-white (>= 250).
        n_sat = np.sum(gray <= 5) + np.sum(gray >= 250)
        metrics['saturation_pct'] = float(n_sat / gray.size)

        # 5. Spatial: width / height
        h, w = gray.shape
        metrics['aspect_ratio'] = w / h

        return metrics

    def run_quality_check(self, image_input: Union[Image.Image, np.ndarray, pydicom.dataset.FileDataset]) -> Dict[str, Any]:
        """
        Main entry point.

        A DICOM dataset only gets the structural check; visual metrics are
        computed for PIL images and numpy arrays.

        Returns: {
            "passed": bool,
            "quality_score": float (0-1),
            "reasons": List[str],
            "metrics": Dict
        }
        """
        # --- PHASE 1: DICOM STRUCTURE (if DICOM) ---
        if isinstance(image_input, pydicom.dataset.FileDataset):
            res_struct = self.evaluate_dicom(image_input)
            if not res_struct['passed']:
                return {
                    "passed": False,
                    "quality_score": 0.0,
                    "reasons": res_struct['reasons'],
                    "metrics": {}
                }

        # Prepare the pixel array for the visual checks.
        if isinstance(image_input, Image.Image):
            img_np = np.array(image_input)
        elif isinstance(image_input, np.ndarray):
            img_np = image_input
        else:
            # Anything else (e.g. a raw DICOM dataset) only had the structural
            # check; the caller is expected to pass the converted image for
            # visual QC.
            return {"passed": True, "quality_score": 1.0, "reasons": [], "metrics": {}}

        # Reject arrays the metric code cannot handle instead of crashing
        # inside cv2 or dividing by a zero height.
        if img_np.size == 0 or img_np.ndim not in (2, 3):
            return {
                "passed": False,
                "quality_score": 0.0,
                "reasons": ["Image Vide ou Dimensions Invalides"],
                "metrics": {}
            }

        # --- PHASE 2: VISUAL METRICS ---
        m = self.compute_metrics(img_np)
        limits = self.thresholds
        reasons = []
        scores = {}

        # Each category is a hard 0/1 decision against its threshold.
        if m['blur_var'] < limits['blur_var']:
            scores['blur'] = 0.0
            reasons.append("Image Floue (Netteté insuffisante)")
        else:
            scores['blur'] = 1.0

        if m['std_dev'] < limits['contrast_std'] or m['entropy'] < limits['entropy']:
            scores['contrast'] = 0.0
            reasons.append("Contraste Insuffisant (Image plate/sombre)")
        else:
            scores['contrast'] = 1.0

        if m['snr'] < limits['snr_min']:
            scores['noise'] = 0.0
            reasons.append("Bruit Excessif (SNR faible)")
        else:
            scores['noise'] = 1.0

        if m['saturation_pct'] > limits['saturation_max']:
            scores['saturation'] = 0.0
            reasons.append("Saturation Excessive (>5% clipping)")
        else:
            scores['saturation'] = 1.0

        if not (limits['aspect_min'] <= m['aspect_ratio'] <= limits['aspect_max']):
            scores['spatial'] = 0.0
            reasons.append(f"Format Anatomique Invalide (Ratio {m['aspect_ratio']:.2f})")
        else:
            scores['spatial'] = 1.0

        # Structural is implicitly 1 once an image array is available.
        scores['structure'] = 1.0

        # --- PHASE 3: GLOBAL SCORE ---
        # Weighted sum, normalised in case the weights do not add up to 1.
        weighted = sum(
            weight * scores.get(category, 1.0)
            for category, weight in self.weights.items()
        )
        final_score = weighted / sum(self.weights.values())

        # DECISION
        is_passed = final_score >= self.pass_threshold

        status = "PASSED" if is_passed else "REJECTED"
        logger.info("QC Evaluation: %s (Score: %.2f) - Reasons: %s", status, final_score, reasons)

        return {
            "passed": is_passed,
            "quality_score": round(final_score, 2),
            "reasons": reasons,
            "metrics": m
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,21 +1,21 @@
1
- fastapi
2
- uvicorn
3
- python-multipart
4
- requests
5
- transformers
6
- torch
7
- Pillow
8
- sentencepiece
9
- pydicom
10
- numpy
11
- grad-cam
12
- python-jose[cryptography]
13
- passlib
14
- argon2-cffi
15
- bcrypt==4.0.1
16
- cryptography
17
- python-dotenv
18
- opencv-python
19
- python-swiftclient
20
- protobuf
21
- huggingface_hub
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ requests
5
+ transformers
6
+ torch
7
+ Pillow
8
+ sentencepiece
9
+ pydicom
10
+ numpy
11
+ grad-cam
12
+ python-jose[cryptography]
13
+ passlib
14
+ argon2-cffi
15
+ bcrypt==4.0.1
16
+ cryptography
17
+ python-dotenv
18
+ opencv-python
19
+ python-swiftclient
20
+ protobuf
21
+ huggingface_hub
scripts/README.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ElephMind Utility Scripts
2
+
3
+ This directory contains maintenance and debug scripts for the ElephMind backend.
4
+
5
+ ## How to Run
6
+
7
+ Because these scripts import modules from the parent `server/` directory, you must run them with the parent directory in your `PYTHONPATH`.
8
+
9
+ **Windows (PowerShell):**
10
+ ```powershell
11
+ $env:PYTHONPATH=".."; python init_admin.py
12
+ ```
13
+
14
+ **Linux/Mac:**
15
+ ```bash
16
+ PYTHONPATH=.. python init_admin.py
17
+ ```
18
+
19
+ ## Available Scripts
20
+
21
+ - **`init_admin.py`**: Creates the initial 'admin' user with secure password hashing.
22
+ - **`verify_admin.py`**: Checks if the admin user exists in the database.
23
+ - **`test_auth.py`**: Smoke tests for the health and authentication endpoints; requires the backend to be running at `http://127.0.0.1:8022`.
24
+ - **`debug_inference.py`**: Tests the ML model with a dummy image.
25
+ - **`inspect_model.py`**: Prints the model's JSON configuration files (`config.json`, `preprocessor_config.json`, `tokenizer_config.json`).
scripts/debug_inference.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModel
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
5
+
6
+ # Configuration
7
+ MODEL_DIR = r"D:\oeil d'elephant"
8
+
9
def test_inference():
    """
    Load the local model and print its scores for two sets of short French
    prompts against a synthetic chest X-ray.

    For every prompt the raw logit, the sigmoid probability and the softmax
    probability (across the prompt set) are printed. Returns early, after
    printing the error, if the model cannot be loaded.
    """
    print(f"Loading model from {MODEL_DIR}...")
    try:
        model = AutoModel.from_pretrained(MODEL_DIR, local_files_only=True)
        processor = AutoProcessor.from_pretrained(MODEL_DIR, local_files_only=True)
        model.eval()

        # Overwrite the logit scale when the model exposes one.
        if hasattr(model, 'logit_scale'):
            with torch.no_grad():
                model.logit_scale.data.fill_(4.60517)  # exp(4.60517) ~= 100

        print("Model loaded.")
    except Exception as load_error:
        print(f"Failed to load model: {load_error}")
        return

    # Synthetic chest X-ray: two light-gray ellipses ("lungs") on black.
    canvas = Image.new('RGB', (448, 448), color=(0, 0, 0))
    pen = ImageDraw.Draw(canvas)
    for lung_box in ([100, 100, 200, 350], [248, 100, 348, 350]):
        pen.ellipse(lung_box, fill=(200, 200, 200))

    # Hypothesis 1: single-word prompts.
    bare_prompts = [
        'Os',
        'Poumons',
        'Peau',
        'Oeil',
        'Sein',
        'Tissu'
    ]

    # Hypothesis 2: slightly more descriptive prompts.
    descriptive_prompts = [
        'Radiographie Os',
        'Radiographie Poumons',
        'Photo Peau',
        "Fond d'oeil",
        'Mammographie Sein',
        'Microscope Tissu'
    ]

    print("\nTesting Simple Prompts on Synthetic Chest X-ray:")

    for prompt_set in (bare_prompts, descriptive_prompts):
        with torch.no_grad():
            batch = processor(text=prompt_set, images=canvas, padding="max_length", return_tensors="pt")
            logits = model(**batch).logits_per_image
            sigmoid_row = torch.sigmoid(logits)[0]
            softmax_row = torch.softmax(logits, dim=1)[0]

        for idx, prompt in enumerate(prompt_set):
            raw_logit = logits[0][idx].item()
            sig = sigmoid_row[idx].item()
            soft = softmax_row[idx].item()
            print(f"Prompt: '{prompt:<20}' | Logit: {raw_logit:.4f} | Sigmoid: {sig*100:.6f}% | Softmax: {soft*100:.2f}%")
        print("-" * 60)
69
+
70
+ if __name__ == "__main__":
71
+ test_inference()
scripts/debug_pathology.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModel
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
5
+
6
+ # Configuration
7
+ MODEL_DIR = r"D:\oeil d'elephant"
8
+
9
def test_inference():
    """
    Load the local model and print logits/sigmoid probabilities for thoracic
    pathology prompts against a synthetic "pneumonia" chest X-ray.

    Two label sets are scored: the original clinical labels and a simplified
    vocabulary. Returns early, after printing the error, if the model cannot
    be loaded.
    """
    print(f"Loading model from {MODEL_DIR}...")
    try:
        model = AutoModel.from_pretrained(MODEL_DIR, local_files_only=True)
        processor = AutoProcessor.from_pretrained(MODEL_DIR, local_files_only=True)
        model.eval()

        # Apply fix: overwrite the logit scale when the model exposes one.
        if hasattr(model, 'logit_scale'):
            with torch.no_grad():
                model.logit_scale.data.fill_(4.60517)  # exp(4.60517) ~= 100

        print("Model loaded.")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    # Synthetic pneumonia X-ray: two lungs, one drawn much brighter to mimic
    # a large consolidation.
    image = Image.new('RGB', (448, 448), color=(0, 0, 0))
    draw = ImageDraw.Draw(image)
    draw.ellipse([100, 100, 200, 350], fill=(100, 100, 100))  # Left lung (clearer)
    draw.ellipse([248, 100, 348, 350], fill=(200, 200, 200))  # Right lung (consolidated/white)

    # "Thoracic" specific labels
    labels = [
        'Cardiomédiastin élargi', 'Cardiomégalie', 'Opacité pulmonaire',
        'Lésion pulmonaire', 'Consolidation', 'Œdème', 'Pneumonie',
        'Atelectasis', 'Pneumothorax', 'Effusion pleurale', 'Pleural Autre'
    ]

    # Simplified versions of the same concepts
    simple_labels = [
        'Coeur', 'Gros coeur', 'Opacité',
        'Lésion', 'Blanc', 'Eau', 'Infection',
        'Ecrasé', 'Air', 'Liquide', 'Autre'
    ]

    print("\nTesting Pathology Prompts:")

    for title, label_set in (("Original Labels", labels), ("Simple Labels", simple_labels)):
        # Inference only: no gradients needed for either label set.
        with torch.no_grad():
            inputs = processor(text=label_set, images=image, padding="max_length", return_tensors="pt")
            logits = model(**inputs).logits_per_image
            # Row 0 is the single input image; probs is 1-D, one entry per
            # label, so it must be indexed as probs[i] (not probs[0][i]).
            probs = torch.sigmoid(logits)[0]

        print(f"\n{title}:")
        for i, label in enumerate(label_set):
            print(f"'{label}': Logit {logits[0][i].item():.4f} | Prob {probs[i].item():.6f}")
69
+ if __name__ == "__main__":
70
+ test_inference()
scripts/init_admin.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
4
+ import database
5
+ from main import get_password_hash
6
+
7
def create_admin():
    """
    Create the initial 'admin' user if it does not exist yet.

    The password and security answer default to the historical values but can
    be overridden through the ELEPHMIND_ADMIN_PASSWORD and
    ELEPHMIND_ADMIN_SECURITY_ANSWER environment variables, so a deployment
    does not have to ship with publicly known credentials. Both values are
    hashed with get_password_hash before being stored.
    """
    import os  # local import keeps this block self-contained

    database.init_db()
    if database.get_user_by_username("admin"):
        print("Admin already exists.")
        return

    default_password = "password123"
    # Empty environment values fall back to the defaults.
    password = os.environ.get("ELEPHMIND_ADMIN_PASSWORD") or default_password
    security_answer = os.environ.get("ELEPHMIND_ADMIN_SECURITY_ANSWER") or "elephant"

    admin_data = {
        "username": "admin",
        "hashed_password": get_password_hash(password),
        "email": "admin@elephmind.com",
        "security_question": "Quel est votre animal totem ?",
        "security_answer": get_password_hash(security_answer)
    }

    if not database.create_user(admin_data):
        print("Failed to create admin user.")
        return

    if password == default_password:
        print("Admin user created successfully. (Login: admin / password123)")
        print("WARNING: default password in use; set ELEPHMIND_ADMIN_PASSWORD or change it after first login.")
    else:
        # Never echo a custom password back to the console.
        print("Admin user created successfully. (Login: admin / password from ELEPHMIND_ADMIN_PASSWORD)")
25
+
26
+ if __name__ == "__main__":
27
+ create_admin()
scripts/inspect_model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ MODEL_DIR = r"D:\oeil d'elephant"
5
+
6
def inspect():
    """
    Print the model's JSON configuration files found in MODEL_DIR.

    config.json is reduced to a handful of summary keys to avoid huge output;
    the preprocessor and tokenizer configs are printed in full. Missing or
    unreadable files are reported and skipped.
    """
    summary_keys = ['architectures', 'model_type', 'logit_scale_init_value', 'vision_config', 'text_config']

    for filename in ("config.json", "preprocessor_config.json", "tokenizer_config.json"):
        full_path = os.path.join(MODEL_DIR, filename)
        print(f"\n--- {filename} ---")

        if not os.path.exists(full_path):
            print("File not found.")
            continue

        try:
            with open(full_path, 'r', encoding='utf-8') as handle:
                parsed = json.load(handle)
            if filename == "config.json":
                # Print summary to avoid huge output
                to_show = {k: v for k, v in parsed.items() if k in summary_keys}
            else:
                to_show = parsed
            print(json.dumps(to_show, indent=2))
        except Exception as read_error:
            print(f"Error reading {filename}: {read_error}")
27
+
28
+ if __name__ == "__main__":
29
+ inspect()
scripts/list_patients.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Script to verify patients in DB
3
+
4
+ import sqlite3
5
+ import os
6
+
7
+ DB_NAME = "elephmind.db"
8
+ if os.path.exists('/data/elephmind.db'):
9
+ DB_NAME = '/data/elephmind.db'
10
+
11
def list_patients():
    """
    Print every row of the ``patients`` table in the database at DB_NAME.

    Prints a notice and returns if the database file does not exist. Query
    errors are printed rather than raised, and the connection is always
    closed.
    """
    db_present = os.path.exists(DB_NAME)
    if not db_present:
        print(f"Database {DB_NAME} not found.")
        return

    connection = sqlite3.connect(DB_NAME)
    # Row objects can be turned into plain dicts for printing.
    connection.row_factory = sqlite3.Row
    cursor = connection.cursor()
    try:
        cursor.execute("SELECT * FROM patients")
        patients = cursor.fetchall()
        print(f"Found {len(patients)} patients.")
        for record in map(dict, patients):
            print(record)
    except Exception as query_error:
        print(f"Error: {query_error}")
    finally:
        connection.close()
29
+
30
+ if __name__ == "__main__":
31
+ list_patients()
scripts/test_auth.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import sys
3
+
4
+ BASE_URL = "http://127.0.0.1:8022"
5
+
6
def test_health():
    """Check that ``{BASE_URL}/health`` answers with HTTP 200.

    Returns:
        True if the endpoint returned 200, False for any other status or
        when the request itself failed (connection error, timeout, ...).
    """
    print(f"Testing Health Check at {BASE_URL}/health...")
    try:
        # Without a timeout, requests waits indefinitely on a dead server.
        r = requests.get(f"{BASE_URL}/health", timeout=10)
    except requests.RequestException as e:
        print(f"❌ Health Check Failed: {e}")
        return False

    if r.status_code == 200:
        print("✅ Health Check Passed")
        return True

    # Report the reason instead of returning False silently.
    print(f"❌ Health Check Failed: unexpected status {r.status_code}")
    return False
16
+
17
def test_auth():
    """Exercise the authentication flow of the API at ``BASE_URL``.

    Steps:
        1. ``POST /analyze`` without a token must be rejected with 401.
        2. ``POST /token`` with the admin credentials must return an
           ``access_token``.
        3. ``POST /analyze`` with that token but without a file is expected
           to answer 422 (validation error), which shows that the token was
           accepted.

    Returns:
        True if the flow behaved as expected, False otherwise. Any exception
        raised along the way is reported on stdout and counts as a failure.
    """
    print("Testing Authentication...")

    # Seconds to wait for each request so the script cannot hang forever.
    timeout = 10

    try:
        # 1. Try to access protected route without token
        r = requests.post(f"{BASE_URL}/analyze", timeout=timeout)
        if r.status_code != 401:
            print(f"❌ Protected Endpoint Failed: Expected 401, got {r.status_code}")
            return False
        print("✅ Protected Endpoint correctly rejected unauthorized request (401)")

        # 2. Login to get token
        payload = {"username": "admin", "password": "secret"}
        r = requests.post(f"{BASE_URL}/token", data=payload, timeout=timeout)
        if r.status_code != 200:
            print(f"❌ Login Failed: {r.status_code} - {r.text}")
            return False
        token = r.json().get("access_token")
        if not token:
            print("❌ Login Failed: No token in response")
            return False
        print("✅ Login Successful. Token received.")

        # 3. Access protected route WITH token. No file is sent, so a 422
        #    (missing field) is the expected answer and means auth passed.
        headers = {"Authorization": f"Bearer {token}"}
        r = requests.post(f"{BASE_URL}/analyze", headers=headers, timeout=timeout)
        if r.status_code == 422:
            print("✅ Protected Endpoint correctly accepted token (Got 422 for missing file, not 401)")
            return True
        if r.status_code == 401:
            print("❌ Protected Endpoint rejected valid token (401)")
            return False
        print(f"⚠️ Unexpected status with token: {r.status_code}")
        return True  # Acceptable for now

    except Exception as e:
        # Top-level boundary of this check: report and fail, never crash.
        print(f"❌ Test Exception: {e}")
        return False
60
+
61
if __name__ == "__main__":
    # The auth test only runs when the health check passes. Exit non-zero on
    # any failure so a shell or CI job can detect it.
    sys.exit(0 if test_health() and test_auth() else 1)
scripts/verify_admin.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import sys
4
+ import os
5
+
6
+ # Add server directory to path to import database
7
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
8
+
9
+ import database
10
+
11
def check_admin():
    """Report whether a user named ``admin`` exists in the database.

    Prints the database name, calls ``database.init_db()`` (presumably to
    create missing tables - confirm in the database module), then looks the
    user up and prints its id and email, or a "not found" message. Errors
    raised by the lookup are printed instead of propagated.
    """
    print(f"Checking database: {database.DB_NAME}")

    # Make sure the schema exists before querying it.
    database.init_db()

    try:
        admin = database.get_user_by_username("admin")
        if not admin:
            print("USER 'admin' NOT FOUND.")
        else:
            print("USER 'admin' FOUND.")
            print(f" ID: {admin['id']}")
            print(f" Email: {admin['email']}")
    except Exception as err:
        print(f"Error querying database: {err}")
27
+
28
+ if __name__ == "__main__":
29
+ check_admin()
30
+
secret.key ADDED
@@ -0,0 +1 @@
 
 
1
+ 6cfBgfzHb12RD2eW_9QxGrpoDdScGYoqpV3MYvz96LE=
storage.py CHANGED
@@ -1,92 +1,92 @@
1
- import os
2
- import abc
3
- from datetime import datetime
4
-
5
- class StorageProvider(abc.ABC):
6
- @abc.abstractmethod
7
- def save_file(self, file_bytes: bytes, filename: str) -> str:
8
- pass
9
-
10
- @abc.abstractmethod
11
- def get_file(self, filename: str) -> bytes:
12
- pass
13
-
14
- class LocalStorage(StorageProvider):
15
- def __init__(self, base_dir="data_storage"):
16
- self.base_dir = base_dir
17
- os.makedirs(base_dir, exist_ok=True)
18
-
19
- def save_file(self, file_bytes: bytes, filename: str) -> str:
20
- # Prepend timestamp to avoid collision
21
- ts = datetime.now().strftime("%Y%m%d_%H%M%S")
22
- safe_name = f"{ts}_{filename}"
23
- path = os.path.join(self.base_dir, safe_name)
24
- with open(path, "wb") as f:
25
- f.write(file_bytes)
26
- return path
27
-
28
- def get_file(self, filename: str) -> bytes:
29
- path = os.path.join(self.base_dir, filename)
30
- if not os.path.exists(path):
31
- return None
32
- with open(path, "rb") as f:
33
- return f.read()
34
-
35
- class SwiftStorage(StorageProvider):
36
- """
37
- OpenStack Swift Storage Provider.
38
- Requires python-swiftclient installed.
39
- """
40
- def __init__(self, auth_url, username, password, project_name, container_name="elephmind_images"):
41
- # Import here to avoid error on Windows if not installed
42
- try:
43
- from swiftclient import Connection
44
- except ImportError:
45
- raise ImportError("python-swiftclient not installed!")
46
-
47
- self.container_name = container_name
48
- self.conn = Connection(
49
- authurl=auth_url,
50
- user=username,
51
- key=password,
52
- tenant_name=project_name,
53
- auth_version='3',
54
- os_options={'user_domain_name': 'Default', 'project_domain_name': 'Default'}
55
- )
56
- # Ensure container exists
57
- try:
58
- self.conn.put_container(self.container_name)
59
- except Exception as e:
60
- print(f"Swift Connection Error: {e}")
61
-
62
- def save_file(self, file_bytes: bytes, filename: str) -> str:
63
- ts = datetime.now().strftime("%Y%m%d_%H%M%S")
64
- safe_name = f"{ts}_{filename}"
65
- self.conn.put_object(
66
- self.container_name,
67
- safe_name,
68
- contents=file_bytes,
69
- content_type='application/octet-stream'
70
- )
71
- return f"swift://{self.container_name}/{safe_name}"
72
-
73
- def get_file(self, filename: str) -> bytes:
74
- # filename could be safe_name
75
- # logic to extract key if needed
76
- try:
77
- _, obj = self.conn.get_object(self.container_name, filename)
78
- return obj
79
- except Exception:
80
- return None
81
-
82
- # Factory
83
- def get_storage_provider(config_mode="LOCAL"):
84
- if config_mode == "OPENSTACK":
85
- return SwiftStorage(
86
- auth_url=os.getenv("OS_AUTH_URL"),
87
- username=os.getenv("OS_USERNAME"),
88
- password=os.getenv("OS_PASSWORD"),
89
- project_name=os.getenv("OS_PROJECT_NAME")
90
- )
91
- else:
92
- return LocalStorage()
 
1
+ import os
2
+ import abc
3
+ from datetime import datetime
4
+
5
class StorageProvider(abc.ABC):
    """Abstract interface every storage backend has to implement."""

    @abc.abstractmethod
    def save_file(self, file_bytes: bytes, filename: str) -> str:
        """Store *file_bytes* under a name derived from *filename*.

        Returns a string that identifies where the data was stored.
        """

    @abc.abstractmethod
    def get_file(self, filename: str) -> bytes:
        """Return the stored content for *filename*."""
13
+
14
class LocalStorage(StorageProvider):
    """Storage backend that keeps files in a directory on the local disk."""

    def __init__(self, base_dir="data_storage"):
        """Remember *base_dir* and create it if it does not exist yet."""
        self.base_dir = base_dir
        os.makedirs(base_dir, exist_ok=True)

    def save_file(self, file_bytes: bytes, filename: str) -> str:
        """Write *file_bytes* to ``<base_dir>/<timestamp>_<name>``.

        Only the last path component of *filename* is used, so names such as
        ``../../x`` or ``/etc/x`` cannot place the file outside ``base_dir``.

        Returns:
            The path of the written file.
        """
        # Treat backslashes as separators too, so Windows-style names are
        # reduced to their final component on every platform.
        leaf = os.path.basename(filename.replace("\\", "/"))
        # Prepend timestamp to avoid collision
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_name = f"{ts}_{leaf}"
        path = os.path.join(self.base_dir, safe_name)
        with open(path, "wb") as f:
            f.write(file_bytes)
        return path

    def get_file(self, filename: str) -> bytes:
        """Return the content of *filename* inside ``base_dir``.

        Returns ``None`` when the file does not exist or when *filename*
        would resolve to a location outside ``base_dir``.
        """
        root = os.path.abspath(self.base_dir)
        # abspath collapses ".." segments; an absolute *filename* replaces
        # root entirely in the join, which the containment check also catches.
        path = os.path.abspath(os.path.join(root, filename))
        try:
            inside = os.path.normcase(os.path.commonpath([root, path])) == os.path.normcase(root)
        except ValueError:
            # Raised e.g. for paths on different drives: certainly not inside.
            inside = False
        if not inside:
            return None
        try:
            with open(path, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return None
34
+
35
class SwiftStorage(StorageProvider):
    """
    OpenStack Swift Storage Provider.
    Requires python-swiftclient installed.
    """

    def __init__(self, auth_url, username, password, project_name, container_name="elephmind_images"):
        """Open a Swift connection and try to create the container.

        Raises:
            ImportError: If python-swiftclient is not installed.
        """
        # Import here so the module still loads where swiftclient is absent.
        try:
            from swiftclient import Connection
        except ImportError as exc:
            raise ImportError("python-swiftclient not installed!") from exc

        self.container_name = container_name
        self.conn = Connection(
            authurl=auth_url,
            user=username,
            key=password,
            tenant_name=project_name,
            auth_version='3',
            os_options={'user_domain_name': 'Default', 'project_domain_name': 'Default'}
        )
        # Ensure container exists. Best effort: a failure is only reported,
        # so construction still succeeds.
        try:
            self.conn.put_container(self.container_name)
        except Exception as e:
            print(f"Swift Connection Error: {e}")

    def save_file(self, file_bytes: bytes, filename: str) -> str:
        """Upload *file_bytes* as ``<timestamp>_<filename>``.

        Returns:
            A ``swift://<container>/<object name>`` URI for the new object.
        """
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_name = f"{ts}_{filename}"
        self.conn.put_object(
            self.container_name,
            safe_name,
            contents=file_bytes,
            content_type='application/octet-stream'
        )
        return f"swift://{self.container_name}/{safe_name}"

    def get_file(self, filename: str) -> bytes:
        """Download an object and return its content, or ``None`` on failure.

        *filename* may be the bare object name or the ``swift://`` URI that
        ``save_file`` returns for this container.
        """
        # Strip the URI prefix so the value returned by save_file round-trips.
        prefix = f"swift://{self.container_name}/"
        if filename.startswith(prefix):
            filename = filename[len(prefix):]
        try:
            _, obj = self.conn.get_object(self.container_name, filename)
            return obj
        except Exception:
            # Any failure (missing object, connection problem, ...) is
            # reported to the caller as "no data".
            return None
81
+
82
+ # Factory
83
def get_storage_provider(config_mode="LOCAL"):
    """Build the storage backend selected by *config_mode*.

    ``"OPENSTACK"`` yields a ``SwiftStorage`` configured from the
    ``OS_AUTH_URL``, ``OS_USERNAME``, ``OS_PASSWORD`` and ``OS_PROJECT_NAME``
    environment variables; every other value yields a ``LocalStorage``.
    """
    if config_mode != "OPENSTACK":
        return LocalStorage()

    swift_settings = {
        "auth_url": os.getenv("OS_AUTH_URL"),
        "username": os.getenv("OS_USERNAME"),
        "password": os.getenv("OS_PASSWORD"),
        "project_name": os.getenv("OS_PROJECT_NAME"),
    }
    return SwiftStorage(**swift_settings)
storage_manager.py CHANGED
@@ -1,85 +1,85 @@
1
- import os
2
- import uuid
3
- import logging
4
- from pathlib import Path
5
- from typing import Tuple, Optional
6
-
7
- # Configure Logging
8
- logging.basicConfig(level=logging.INFO)
9
- logger = logging.getLogger(__name__)
10
-
11
- # Detect environment (Hugging Face Spaces vs Local)
12
- # HF Spaces with persistent storage usually mount at /data
13
- IS_HF_SPACE = os.path.exists('/data')
14
- if IS_HF_SPACE:
15
- BASE_STORAGE_DIR = Path('/data/storage')
16
- logger.info(f"Using PERSISTENT storage at {BASE_STORAGE_DIR}")
17
- else:
18
- BASE_STORAGE_DIR = Path(os.path.dirname(os.path.abspath(__file__))) / "storage"
19
- logger.info(f"Using LOCAL storage at {BASE_STORAGE_DIR}")
20
-
21
- def get_user_storage_path(username: str) -> Path:
22
- """Get secure storage path for user, creating it if needed."""
23
- # Sanitize username to prevent directory traversal
24
- safe_username = "".join([c for c in username if c.isalnum() or c in ('-', '_')])
25
- user_path = BASE_STORAGE_DIR / safe_username
26
- user_path.mkdir(parents=True, exist_ok=True)
27
- return user_path
28
-
29
- def save_image(username: str, file_bytes: bytes, filename_hint: str = "image.png") -> str:
30
- """
31
- Save image to disk and return a unique image_id.
32
- Returns: image_id (e.g. IMG_ABC123)
33
- """
34
- # Generate ID
35
- unique_suffix = uuid.uuid4().hex[:12].upper()
36
- image_id = f"IMG_{unique_suffix}"
37
-
38
- # Determine extension
39
- ext = os.path.splitext(filename_hint)[1].lower()
40
- if not ext:
41
- ext = ".png" # Default
42
-
43
- filename = f"{image_id}{ext}"
44
- user_path = get_user_storage_path(username)
45
- file_path = user_path / filename
46
-
47
- try:
48
- with open(file_path, "wb") as f:
49
- f.write(file_bytes)
50
- logger.info(f"Saved image {image_id} for user {username} at {file_path}")
51
- return image_id
52
- except Exception as e:
53
- logger.error(f"Failed to save image: {e}")
54
- raise IOError(f"Storage Error: {e}")
55
-
56
- def load_image(username: str, image_id: str) -> Tuple[bytes, str]:
57
- """
58
- Load image bytes from disk.
59
- Returns: (file_bytes, file_path_str)
60
- """
61
- # Security: Ensure ID format is valid
62
- if not image_id.startswith("IMG_") or ".." in image_id or "/" in image_id:
63
- raise ValueError("Invalid image_id format")
64
-
65
- user_path = get_user_storage_path(username)
66
-
67
- # We don't know the extension, so look for the file
68
- # Or strict requirement: user must know?
69
- # Better: Search for matching file
70
- for file in user_path.glob(f"{image_id}.*"):
71
- try:
72
- with open(file, "rb") as f:
73
- return f.read(), str(file)
74
- except Exception as e:
75
- logger.error(f"Error reading file {file}: {e}")
76
- raise IOError("Read error")
77
-
78
- raise FileNotFoundError(f"Image {image_id} not found for user {username}")
79
-
80
- def get_image_absolute_path(username: str, image_id: str) -> Optional[str]:
81
- """Return absolute path if exists, else None."""
82
- user_path = get_user_storage_path(username)
83
- for file in user_path.glob(f"{image_id}.*"):
84
- return str(file)
85
- return None
 
1
+ import os
2
+ import uuid
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Tuple, Optional
6
+
7
+ # Configure Logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # Detect environment (Hugging Face Spaces vs Local)
12
+ # HF Spaces with persistent storage usually mount at /data
13
+ IS_HF_SPACE = os.path.exists('/data')
14
+ if IS_HF_SPACE:
15
+ BASE_STORAGE_DIR = Path('/data/storage')
16
+ logger.info(f"Using PERSISTENT storage at {BASE_STORAGE_DIR}")
17
+ else:
18
+ BASE_STORAGE_DIR = Path(os.path.dirname(os.path.abspath(__file__))) / "storage"
19
+ logger.info(f"Using LOCAL storage at {BASE_STORAGE_DIR}")
20
+
21
def get_user_storage_path(username: str) -> Path:
    """Return the storage directory of *username*, creating it if needed.

    The directory name consists of the alphanumeric, ``-`` and ``_``
    characters of *username*; everything else is dropped so the name cannot
    contain path separators or ``..``.

    Raises:
        ValueError: If no usable character remains. Without this check such a
            user would be mapped to the storage root shared by all users.
    """
    # Sanitize username to prevent directory traversal
    safe_username = "".join(c for c in username if c.isalnum() or c in ('-', '_'))
    if not safe_username:
        raise ValueError("Invalid username for storage path")
    # NOTE(review): distinct usernames that differ only in dropped characters
    # (e.g. "a.b" and "ab") share one directory - confirm that usernames are
    # restricted accordingly at registration.
    user_path = BASE_STORAGE_DIR / safe_username
    user_path.mkdir(parents=True, exist_ok=True)
    return user_path
28
+
29
def save_image(username: str, file_bytes: bytes, filename_hint: str = "image.png") -> str:
    """
    Write an image into the user's storage directory under a fresh id.

    The id is ``IMG_`` followed by 12 upper-case hex characters; the file
    extension is taken from filename_hint (lower-cased, ``.png`` if absent).
    Returns: image_id (e.g. IMG_ABC123)
    Raises: IOError if the file could not be written.
    """
    image_id = "IMG_" + uuid.uuid4().hex[:12].upper()

    # Fall back to PNG when the hint carries no extension.
    suffix = os.path.splitext(filename_hint)[1].lower() or ".png"

    target = get_user_storage_path(username) / f"{image_id}{suffix}"

    try:
        with open(target, "wb") as out:
            out.write(file_bytes)
        logger.info(f"Saved image {image_id} for user {username} at {target}")
        return image_id
    except Exception as err:
        logger.error(f"Failed to save image: {err}")
        raise IOError(f"Storage Error: {err}")
55
+
56
def load_image(username: str, image_id: str) -> Tuple[bytes, str]:
    """
    Load image bytes from disk.

    The extension is not part of the id, so the first file in the user's
    directory named ``<image_id>.<anything>`` is used.
    Returns: (file_bytes, file_path_str)
    Raises:
        ValueError: image_id is not ``IMG_`` followed by letters, digits,
            ``-`` or ``_``.
        IOError: the matching file could not be read.
        FileNotFoundError: no file matches image_id for this user.
    """
    import re

    # Security: allow-list the id. It is interpolated into a glob pattern
    # below, so path separators, "..", and wildcards such as "*" or "[" must
    # never get through.
    if not re.fullmatch(r"IMG_[A-Za-z0-9_-]+", image_id):
        raise ValueError("Invalid image_id format")

    user_path = get_user_storage_path(username)

    for candidate in user_path.glob(f"{image_id}.*"):
        try:
            return candidate.read_bytes(), str(candidate)
        except OSError as e:
            logger.error(f"Error reading file {candidate}: {e}")
            raise IOError("Read error") from e

    raise FileNotFoundError(f"Image {image_id} not found for user {username}")
79
+
80
def get_image_absolute_path(username: str, image_id: str) -> Optional[str]:
    """Return the path of the user's image file if it exists, else None.

    An image_id containing anything other than letters, digits, ``-`` or
    ``_`` is treated as "not found": the id is interpolated into a glob
    pattern, so separators, ``..`` and wildcards must not reach it.
    """
    import re

    if not re.fullmatch(r"[A-Za-z0-9_-]+", image_id):
        return None

    user_path = get_user_storage_path(username)
    # The extension is unknown, so take the first "<image_id>.<ext>" match.
    match = next(user_path.glob(f"{image_id}.*"), None)
    return str(match) if match is not None else None
upload_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # upload_model.py - Upload model to Hugging Face Hub
2
+ from huggingface_hub import upload_folder
3
+ import os
4
+
5
# Local folder that holds the model files to publish.
model_path = os.path.join("models", "oeil d'elephant")
print(f"Uploading from: {model_path}")
print(f"Path exists: {os.path.exists(model_path)}")

if not os.path.exists(model_path):
    # Nothing to upload: show what is available to help spot a wrong path.
    print(f"ERROR: Path not found: {model_path}")
    print("Available in models/:")
    if os.path.exists("models"):
        print(os.listdir("models"))
else:
    print("Starting upload... (this may take a while for 3.5GB)")
    upload_folder(
        folder_path=model_path,
        repo_id="issoufzousko07/medsigclip-model",
        repo_type="model",
    )
    print("Upload complete!")
upload_space.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # upload_space.py - Upload code to HuggingFace Space (excluding large model files)
2
+ from huggingface_hub import upload_folder
3
+ import os
4
+
5
print("Uploading ElephMind API to HuggingFace Space...")
print("(Model will be downloaded from Hub at runtime)")

# Determine the server directory (where this script lives)
server_dir = os.path.dirname(os.path.abspath(__file__))

print(f"Uploading directory: {server_dir}")

# Files that must not be pushed to the Space. Patterns are matched against
# each file's path relative to server_dir, so a bare directory name such as
# "venv" does not cover the files inside it; the "<dir>/*" forms do.
# Secrets (.env, secret.key and other key files) must never be uploaded.
ignore_patterns = [
    "models/*", "*.pyc", "__pycache__", "*.db", "storage/*", "data_storage/*",
    ".env", "venv", ".git", ".idea",
    "__pycache__/*", "*/__pycache__/*", "venv/*", ".git/*", ".idea/*",
    "secret.key", "*.key",
]

upload_folder(
    folder_path=server_dir,
    repo_id="issoufzousko07/elephmind-api",
    repo_type="space",
    ignore_patterns=ignore_patterns,
)

print("[OK] Upload complete!")
print("Your Space should start building at: https://huggingface.co/spaces/issoufzousko07/elephmind-api")