Enterwar99 committed on
Commit
b0441fb
verified
1 Parent(s): f7c13c9

Update api_app.py

Browse files
Files changed (1) hide show
  1. api_app.py +185 -233
api_app.py CHANGED
@@ -8,30 +8,27 @@ import torch.nn as nn
8
  import io
9
  import numpy as np
10
  import os
11
- from typing import List, Dict, Any, Optional # Dodano Optional
12
- import logging # Dodajemy import modu艂u logging
13
- import cv2 # Dodajemy OpenCV
14
- import base64 # Dodajemy base64
15
 
16
- # Importy dla Grad-CAM
17
  from pytorch_grad_cam import GradCAMPlusPlus
18
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
19
- from huggingface_hub import hf_hub_download # Do pobierania modelu z Huba
20
- from pydantic import BaseModel # Dodano Pydantic
21
 
22
  # --- Konfiguracja Logowania ---
23
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
24
  logger = logging.getLogger(__name__)
25
 
26
  # --- Konfiguracja ---
27
- # Upewnij si臋, 偶e te warto艣ci s膮 zgodne z Twoim repozytorium modelu
28
  HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
29
- MODEL_FILENAME = "best_model.pth" # Nazwa pliku modelu w repozytorium
30
 
31
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
33
  IMAGENET_STD = [0.229, 0.224, 0.225]
34
- # Progi b臋d膮 teraz przekazywane jako parametry zapytania
35
 
36
  # Globalne zmienne dla modelu i transformacji
37
  model_instance = None
@@ -51,174 +48,154 @@ def initialize_model():
51
  if model_instance is not None:
52
  return
53
 
54
- logger.info(f"Rozpoczynanie inicjalizacji modelu...")
55
- logger.info(f"Pobieranie pliku modelu '{MODEL_FILENAME}' z repozytorium '{HF_MODEL_REPO_ID}'...")
56
  try:
57
- # Odczytaj token z sekret贸w, je艣li jest dost臋pny
58
- # Nazwa zmiennej 艣rodowiskowej musi by膰 taka sama jak nazwa sekretu w ustawieniach Space
59
  hf_auth_token = os.environ.get("HF_TOKEN_MODEL_READ")
60
- if hf_auth_token:
61
- logger.info("U偶ywam tokenu HF_TOKEN_MODEL_READ do pobrania modelu.")
62
- else:
63
- logger.warning("Sekret HF_TOKEN_MODEL_READ nie zosta艂 znaleziony. Pr贸ba pobrania modelu bez tokenu.")
64
-
65
  model_pt_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME, token=hf_auth_token)
66
  logger.info(f"Plik modelu pomy艣lnie pobrany do: {model_pt_path}")
67
  except Exception as e:
68
  logger.error(f"B艂膮d podczas pobierania modelu z Hugging Face Hub: {e}", exc_info=True)
69
  raise RuntimeError(f"Nie mo偶na pobra膰 modelu: {e}")
70
 
71
- logger.info(f"Inicjalizacja architektury modelu ResNet-18...")
72
  model_arch = models.resnet18(weights=None)
73
  num_feats = model_arch.fc.in_features
74
- model_arch.fc = nn.Sequential(
75
- nn.Dropout(0.5),
76
- nn.Linear(num_feats, 5)
77
- )
78
- logger.info(f"Architektura modelu ResNet-18 zainicjalizowana.")
79
-
80
- logger.info(f"艁adowanie wag modelu z {model_pt_path}...")
81
  model_arch.load_state_dict(torch.load(model_pt_path, map_location=DEVICE))
82
  model_arch.to(DEVICE)
83
  model_arch.eval()
84
  model_instance = model_arch
85
- logger.info(f"Wagi modelu za艂adowane. Model przeniesiony na urz膮dzenie: {DEVICE} i ustawiony w tryb eval().")
86
  transform_pipeline = transforms.Compose([
87
  transforms.Resize((224, 224)),
88
  transforms.ToTensor(),
89
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) # Upewnij si臋, 偶e IMAGENET_MEAN i IMAGENET_STD s膮 zdefiniowane
90
  ])
91
  logger.info(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")
92
 
93
  # --- Funkcja do predykcji z kwantyfikacj膮 niepewno艣ci (MC Dropout) ---
94
- def predict_with_mc_dropout(current_model_instance, input_tensor_on_device, mc_dropout_samples: int, uncertainty_threshold_std: float): # Dodano parametry
95
- """
96
- Wykonuje predykcj臋 z u偶yciem Monte Carlo Dropout do oszacowania niepewno艣ci.
97
- """
98
- logger.info(f"Performing MC Dropout with {mc_dropout_samples} samples. Uncertainty threshold (std): {uncertainty_threshold_std}") # U偶yj parametr贸w
99
 
100
  original_mode_is_training = current_model_instance.training
101
- current_model_instance.train() # W艂膮cz warstwy dropout
 
 
 
 
 
102
 
103
- all_probs_list = []
104
- with torch.no_grad(): # Gradienty nie s膮 potrzebne do samego przej艣cia w prz贸d
105
- for _ in range(mc_dropout_samples): # Poprawiona nazwa zmiennej
106
- output = current_model_instance(input_tensor_on_device)
107
  probs_tensor = torch.nn.functional.softmax(output, dim=1)
108
- all_probs_list.append(probs_tensor.cpu().numpy())
109
 
110
- # Przywr贸膰 oryginalny tryb modelu
111
  if not original_mode_is_training:
112
  current_model_instance.eval()
113
 
114
- predictions_stack = np.vstack(all_probs_list) # Kszta艂t: (n_samples, num_classes)
115
- mean_probabilities = np.mean(predictions_stack, axis=0) # Kszta艂t: (num_classes,)
116
- std_dev_probabilities = np.std(predictions_stack, axis=0) # Kszta艂t: (num_classes,)
117
-
118
- predicted_class_index = np.argmax(mean_probabilities)
119
- confidence_in_predicted_class = float(mean_probabilities[predicted_class_index]) # Jawna konwersja do float
120
 
121
- # U偶yj 艣redniej odchyle艅 standardowych prawdopodobie艅stw wszystkich klas jako metryki niepewno艣ci
122
- uncertainty_metric = np.mean(std_dev_probabilities)
123
-
124
- is_uncertain = uncertainty_metric > uncertainty_threshold_std # U偶yj uncertainty_threshold_std
125
- logger.info(f"MC Dropout Results: Predicted Index: {int(predicted_class_index)}, Confidence: {confidence_in_predicted_class:.4f}, Uncertainty (avg_std): {uncertainty_metric:.4f}, Is Uncertain: {is_uncertain}")
126
-
127
- birads_category_if_confident = int(predicted_class_index) + 1 # Jawna konwersja do int
128
- base_result = {
129
- "birads": int(birads_category_if_confident) if not is_uncertain else None, # Jawna konwersja do int
130
- "confidence": float(confidence_in_predicted_class) if not is_uncertain else None,
131
- "interpretation": interpretations_dict.get(birads_category_if_confident, "Nieznana klasyfikacja") if not is_uncertain \
132
- else f"Model jest niepewny co do tego obrazu (niepewno艣膰: {uncertainty_metric:.4f}). Sprawd藕 jako艣膰 i typ obrazu. Mo偶e to by膰 obraz spoza domeny medycznej.",
133
- "class_probabilities": {str(j + 1): float(mean_probabilities[j]) for j in range(len(mean_probabilities))},
134
- "grad_cam_image_base64": None, # Zostanie wype艂nione p贸藕niej, je艣li pewne
135
- "error": "High prediction uncertainty" if is_uncertain else None,
136
- "details": f"Uncertainty metric ({uncertainty_metric:.4f}) {'przekroczy艂a' if is_uncertain else 'jest w granicach'} progu ({uncertainty_threshold_std})." # U偶yj uncertainty_threshold_std
137
- }
138
- return base_result, predicted_class_index # Zwr贸膰 r贸wnie偶 indeks dla Grad-CAM
139
-
140
- # --- Funkcja do tworzenia obrazu z na艂o偶on膮 map膮 Grad-CAM (zaadaptowana z app.py) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  def create_grad_cam_overlay_image(original_pil_image: Image.Image, grayscale_cam: np.ndarray, birads_category: int, transparency: float = 0.5) -> Image.Image:
142
- """Tworzy obraz PIL z na艂o偶on膮 map膮 Grad-CAM."""
143
- logger.info(f"Rozpoczynanie tworzenia obrazu Grad-CAM overlay dla kategorii {birads_category}")
144
  try:
145
  img_np = np.array(original_pil_image.convert('RGB')).astype(np.float32) / 255.0
146
  cam_resized = cv2.resize(grayscale_cam, (img_np.shape[1], img_np.shape[0]))
147
-
148
  cam_normalized = (cam_resized - np.min(cam_resized)) / (np.max(cam_resized) - np.min(cam_resized) + 1e-8)
149
-
150
  threshold = 0.7
151
  cam_normalized[cam_normalized < threshold] = 0
152
-
153
- kernel_size = 5
154
- kernel = np.ones((kernel_size, kernel_size), np.uint8)
155
  cam_cleaned = cv2.morphologyEx(cam_normalized, cv2.MORPH_OPEN, kernel)
156
-
157
  birads_colors_rgb = {
158
- 1: (0.1, 0.7, 0.1),
159
- 2: (0.53, 0.81, 0.92),
160
- 3: (1.0, 0.9, 0.0),
161
- 4: (1.0, 0.5, 0.0),
162
- 5: (0.9, 0.1, 0.1)
163
  }
164
  chosen_color = np.array(birads_colors_rgb.get(birads_category, (0.5, 0.5, 0.5)))
165
-
166
  color_overlay_np = np.zeros_like(img_np)
167
- for c in range(3):
168
- color_overlay_np[:, :, c] = chosen_color[c]
169
-
170
  alpha = cam_cleaned * transparency
171
  alpha_expanded = alpha[..., np.newaxis]
172
-
173
  highlighted_image_np = img_np * (1 - alpha_expanded) + color_overlay_np * alpha_expanded
174
  highlighted_image_np = np.clip(highlighted_image_np, 0, 1)
175
  final_image_np = (highlighted_image_np * 255).astype(np.uint8)
176
- logger.info("Obraz Grad-CAM overlay pomy艣lnie utworzony.")
177
  return Image.fromarray(final_image_np)
178
  except Exception as e:
179
  logger.error(f"B艂膮d podczas tworzenia obrazu Grad-CAM overlay: {e}", exc_info=True)
180
  return None
181
 
182
- # --- Funkcja do heurystycznych test贸w OOD ---
183
- def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str, colorfulness_threshold: float, uniformity_threshold: float, aspect_ratio_min: float, aspect_ratio_max: float) -> bool: # Dodano parametry
184
  """
185
- Wykonuje zestaw prostych heurystyk do wykrywania obraz贸w spoza dystrybucji.
186
  """
187
  logger.info(f"[RequestID: {request_id}] Uruchamianie heurystycznych test贸w OOD...")
188
  width, height = pil_image.size
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
- # Heurystyka 1: Sprawdzenie proporcji obrazu
191
  aspect_ratio = width / height
192
- if not (aspect_ratio_min < aspect_ratio < aspect_ratio_max): # U偶yj parametr贸w
193
- logger.warning(f"[RequestID: {request_id}] Heurystyka OOD: Nietypowe proporcje obrazu: {aspect_ratio:.2f} (min: {aspect_ratio_min}, max: {aspect_ratio_max}). Odrzucam.")
194
- return False
195
-
196
- # Heurystyka 2: Analiza "kolorowo艣ci" (dla obraz贸w RGB)
197
- # Mammografie s膮 w skali szaro艣ci; je艣li obraz jest kolorowy, ta metryka b臋dzie wysoka.
198
- # Zak艂adamy, 偶e pil_image mo偶e by膰 ju偶 w trybie RGB lub zostanie do niego skonwertowany.
199
- # Je艣li pil_image jest w trybie 'L', mo偶na by t臋 heurystyk臋 pomin膮膰 lub dostosowa膰.
200
- # Dla obrazu w skali szaro艣ci skonwertowanego do RGB, R=G=B, wi臋c std_per_pixel_across_channels b臋dzie bliskie 0.
201
- img_rgb_for_color_check = pil_image.convert('RGB') # Upewnijmy si臋, 偶e pracujemy na RGB dla tej heurystyki
202
- img_np_rgb = np.array(img_rgb_for_color_check)
203
- std_per_pixel_across_channels = np.std(img_np_rgb, axis=2) # Odch. std. dla ka偶dego piksela po kana艂ach R,G,B
204
- mean_std_across_channels = np.mean(std_per_pixel_across_channels)
205
- if mean_std_across_channels > colorfulness_threshold: # U偶yj parametru
206
- logger.warning(f"[RequestID: {request_id}] Heurystyka OOD: Obraz wydaje si臋 zbyt kolorowy. 艢rednie odch. std. mi臋dzy kana艂ami: {mean_std_across_channels:.2f}. Odrzucam.")
207
- return False
208
-
209
- # Heurystyka 3: Sprawdzenie, czy obraz nie jest prawie jednolity (dominuj膮ca jasno艣膰/ciemno艣膰)
210
- # Konwertujemy do skali szaro艣ci dla tej analizy
211
  gray_image = pil_image.convert('L')
212
  std_dev_intensity = np.std(np.array(gray_image))
213
- if std_dev_intensity < uniformity_threshold: # U偶yj parametru
214
- logger.warning(f"[RequestID: {request_id}] Heurystyka OOD: Obraz wydaje si臋 zbyt jednolity (ma艂o zr贸偶nicowania jasno艣ci). Odch. std. intensywno艣ci: {std_dev_intensity:.2f}. Odrzucam.")
215
- return False
216
-
217
- logger.info(f"[RequestID: {request_id}] Heurystyczne testy OOD zako艅czone pomy艣lnie. Obraz wygl膮da na potencjalnie poprawny.")
218
- return True
219
 
220
  # --- Aplikacja FastAPI ---
221
- # --- Definicja modelu odpowiedzi Pydantic ---
222
  class PredictionResult(BaseModel):
223
  birads: Optional[int] = None
224
  confidence: Optional[float] = None
@@ -228,19 +205,17 @@ class PredictionResult(BaseModel):
228
  error: Optional[str] = None
229
  details: Optional[str] = None
230
 
231
-
232
  app = FastAPI(title="BI-RADS Mammography Classification API")
233
 
234
  @app.on_event("startup")
235
  async def startup_event():
236
- """Wywo艂ywane przy starcie aplikacji FastAPI."""
237
  logger.info("Rozpoczynanie eventu startup aplikacji FastAPI.")
238
  initialize_model()
239
 
240
- @app.post("/predict/", response_model=List[PredictionResult]) # U偶ycie modelu Pydantic
241
- # Dodano parametry zapytania z warto艣ciami domy艣lnymi
242
- async def predict_image(
243
- file: UploadFile = File(...),
244
  colorfulness_threshold: float = 15.0,
245
  uniformity_threshold: float = 10.0,
246
  aspect_ratio_min: float = 0.4,
@@ -248,124 +223,101 @@ async def predict_image(
248
  mc_dropout_samples: int = 25,
249
  uncertainty_threshold_std: float = 0.08
250
  ):
251
- """
252
- Endpoint do klasyfikacji obrazu mammograficznego.
253
- """
254
- request_id = os.urandom(8).hex() # Prosty identyfikator 偶膮dania
255
- logger.info(f"[RequestID: {request_id}] Otrzymano 偶膮danie /predict/")
256
- logger.info(f"[RequestID: {request_id}] Received parameters: colorfulness_threshold={colorfulness_threshold}, uniformity_threshold={uniformity_threshold}, aspect_ratio_min={aspect_ratio_min}, aspect_ratio_max={aspect_ratio_max}, mc_dropout_samples={mc_dropout_samples}, uncertainty_threshold_std={uncertainty_threshold_std}")
257
 
258
  if model_instance is None or transform_pipeline is None:
259
- logger.error(f"[RequestID: {request_id}] Model nie jest zainicjalizowany podczas 偶膮dania /predict/.")
260
- raise HTTPException(status_code=503, detail="Model nie jest zainicjalizowany. Spr贸buj ponownie za chwil臋.")
261
-
262
- try:
263
- logger.info(f"[RequestID: {request_id}] Odczytywanie i przetwarzanie wgranego pliku...")
264
- contents = await file.read()
265
- # Wczytaj obraz, ale jeszcze nie konwertuj do RGB, aby heurystyki mog艂y dzia艂a膰 na bardziej "surowych" danych
266
- image_pil_original = Image.open(io.BytesIO(contents))
267
- logger.info(f"[RequestID: {request_id}] Plik obrazu pomy艣lnie odczytany. Oryginalny tryb: {image_pil_original.mode}, Rozmiar: {image_pil_original.size}")
268
-
269
- # --- Etap: Heurystyczne testy OOD (PRE-FILTR) ---
270
- # U偶ywamy .copy(), aby unikn膮膰 potencjalnych modyfikacji oryginalnego obiektu image_pil_original
271
- # przez funkcj臋 run_heuristic_ood_checks, je艣li by takie wykonywa艂a (np. konwersje inplace).
272
- # Przekazujemy parametry do funkcji
273
- if not run_heuristic_ood_checks(
274
- image_pil_original.copy(), request_id,
275
- colorfulness_threshold, uniformity_threshold, aspect_ratio_min, aspect_ratio_max
276
- ):
277
- # Heurystyki wykry艂y problem
278
- return JSONResponse(
279
- status_code=400, # Bad Request
280
- content=[
281
- {
282
- "interpretation": "Obraz odrzucony przez wst臋pne testy heurystyczne. Nie wygl膮da na poprawny obraz medyczny.",
283
- "class_probabilities": {}, # Puste, bo nie by艂o predykcji klas
284
- "error": "Image does not appear to be a valid medical mammogram based on initial checks.",
285
- "details": "Heurystyczne testy OOD nie powiod艂y si臋. Obraz odrzucony przed analiz膮 przez model AI.",
286
- "birads": None,
287
- "confidence": None,
288
- "grad_cam_image_base64": None
289
- }
290
- ]
291
  )
292
- image = image_pil_original.convert("RGB") # Teraz konwertuj do RGB dla modelu g艂贸wnego
293
- except Exception as e:
294
- logger.error(f"[RequestID: {request_id}] B艂膮d podczas odczytu pliku obrazu: {e}", exc_info=True)
295
- raise HTTPException(status_code=400, detail=f"Nie mo偶na odczyta膰 pliku obrazu: {e}")
296
-
297
- # --- Preprocessing ---
298
- logger.info(f"[RequestID: {request_id}] Rozpoczynanie preprocessingu obrazu...")
299
- input_tensor = transform_pipeline(image).unsqueeze(0).to(DEVICE)
300
- logger.info(f"[RequestID: {request_id}] Preprocessing zako艅czony. Kszta艂t tensora wej艣ciowego: {input_tensor.shape}")
301
-
302
- # --- Etap: Predykcja z kwantyfikacj膮 niepewno艣ci (MC Dropout) ---
303
- # Upewnij si臋, 偶e model_instance, DEVICE, interpretations_dict s膮 dost臋pne
304
- mc_output_dict, predicted_idx_from_mc = predict_with_mc_dropout(
305
- model_instance,
306
- input_tensor,
307
- mc_dropout_samples, uncertainty_threshold_std # Przekazujemy parametry
308
- )
309
-
310
- results = []
311
- if mc_output_dict.get("error") == "High prediction uncertainty":
312
- logger.warning(f"[RequestID: {request_id}] Wysoka niepewno艣膰 predykcji: {mc_output_dict.get('details')}")
313
- results.append(mc_output_dict) # Dodaj wynik z informacj膮 o niepewno艣ci
314
- return JSONResponse(content=results) # API oczekuje listy
315
-
316
- # --- Je艣li predykcja jest pewna, kontynuuj z Grad-CAM ---
317
- # mc_output_dict zawiera ju偶 'birads', 'confidence', 'interpretation', 'class_probabilities'
318
- birads_category_for_cam = mc_output_dict["birads"] # To jest ju偶 BI-RADS 1-5
319
- # predicted_idx_from_mc to indeks 0-4
320
-
321
- # Upewnij si臋, 偶e model jest w trybie ewaluacji dla Grad-CAM
322
- # Funkcja MC Dropout powinna go przywr贸ci膰, ale dla pewno艣ci:
323
- model_instance.eval()
324
-
325
- # Generowanie Grad-CAM
326
- grad_cam_map_serialized = None # Zmienna zdefiniowana przed blokiem try
327
- grad_cam_image_base64 = None # Zmienna zdefiniowana przed blokiem try
328
- # Poprawka: U偶yj birads_category_for_cam zamiast niezdefiniowanej birads_category
329
- logger.info(f"[RequestID: {request_id}] Rozpoczynanie generowania Grad-CAM dla kategorii {birads_category_for_cam}...")
330
- try:
331
- # model_instance is already in eval mode from initialize_model()
 
 
332
  target_layers = [model_instance.layer4[-1]]
333
  cam_algorithm = GradCAMPlusPlus(model=model_instance, target_layers=target_layers)
334
 
335
- # Tensor wej艣ciowy dla CAM musi mie膰 requires_grad=True
336
- current_input_tensor_for_cam = input_tensor.clone().detach().requires_grad_(True)
337
- targets_for_cam = [ClassifierOutputTarget(predicted_idx_from_mc)] # U偶yj indeksu z MC Dropout
338
-
339
- grayscale_cam = cam_algorithm(input_tensor=current_input_tensor_for_cam, targets=targets_for_cam)
340
- if grayscale_cam is not None:
341
- grad_cam_map_np = grayscale_cam[0, :]
342
- logger.info(f"[RequestID: {request_id}] Grad-CAM wygenerowany pomy艣lnie.")
343
-
344
- # Tworzenie obrazu z na艂o偶on膮 map膮 Grad-CAM
345
- overlay_image_pil = create_grad_cam_overlay_image(original_pil_image=image, # oryginalny obraz PIL
346
- grayscale_cam=grad_cam_map_np,
347
- birads_category=birads_category_for_cam)
348
- if overlay_image_pil:
349
- buffered = io.BytesIO()
350
- overlay_image_pil.save(buffered, format="PNG") # Zapisz jako PNG
351
- grad_cam_image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
352
- logger.info(f"[RequestID: {request_id}] Obraz Grad-CAM overlay zakodowany do base64.")
353
- else:
354
- logger.warning(f"[RequestID: {request_id}] Nie uda艂o si臋 utworzy膰 obrazu Grad-CAM overlay.")
355
- else:
356
- logger.warning(f"[RequestID: {request_id}] Wygenerowany Grad-CAM jest None.")
357
- except Exception as e:
358
- logger.error(f"[RequestID: {request_id}] B艂膮d podczas generowania Grad-CAM w API: {e}", exc_info=True)
359
- # grad_cam_image_base64 pozostanie None
360
-
361
- mc_output_dict["grad_cam_image_base64"] = grad_cam_image_base64 # Zaktualizuj s艂ownik wynikowy
362
- results.append(mc_output_dict) # Dodaj finalny wynik (pewny, z Grad-CAM lub bez)
363
- logger.info(f"[RequestID: {request_id}] Przetwarzanie 偶膮dania /predict/ zako艅czone. Zwracam wyniki.")
364
- return JSONResponse(content=results)
365
 
366
  @app.get("/")
367
  async def root():
368
  logger.info("Otrzymano 偶膮danie GET na /")
369
- return {"message": "Witaj w BI-RADS Classification API! U偶yj endpointu /predict/ do wysy艂ania obraz贸w."}
370
-
371
- # Do uruchomienia lokalnie: uvicorn api_app:app --reload
 
8
  import io
9
  import numpy as np
10
  import os
11
+ from typing import List, Dict, Any, Optional
12
+ import logging
13
+ import cv2
14
+ import base64
15
 
 
16
  from pytorch_grad_cam import GradCAMPlusPlus
17
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
18
+ from huggingface_hub import hf_hub_download
19
+ from pydantic import BaseModel
20
 
21
  # --- Konfiguracja Logowania ---
22
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
  logger = logging.getLogger(__name__)
24
 
25
  # --- Konfiguracja ---
 
26
  HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
27
+ MODEL_FILENAME = "best_model.pth"
28
 
29
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
31
  IMAGENET_STD = [0.229, 0.224, 0.225]
 
32
 
33
  # Globalne zmienne dla modelu i transformacji
34
  model_instance = None
 
48
  if model_instance is not None:
49
  return
50
 
51
+ logger.info("Rozpoczynanie inicjalizacji modelu...")
 
52
  try:
 
 
53
  hf_auth_token = os.environ.get("HF_TOKEN_MODEL_READ")
 
 
 
 
 
54
  model_pt_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME, token=hf_auth_token)
55
  logger.info(f"Plik modelu pomy艣lnie pobrany do: {model_pt_path}")
56
  except Exception as e:
57
  logger.error(f"B艂膮d podczas pobierania modelu z Hugging Face Hub: {e}", exc_info=True)
58
  raise RuntimeError(f"Nie mo偶na pobra膰 modelu: {e}")
59
 
 
60
  model_arch = models.resnet18(weights=None)
61
  num_feats = model_arch.fc.in_features
62
+ model_arch.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(num_feats, 5))
63
+
 
 
 
 
 
64
  model_arch.load_state_dict(torch.load(model_pt_path, map_location=DEVICE))
65
  model_arch.to(DEVICE)
66
  model_arch.eval()
67
  model_instance = model_arch
68
+
69
  transform_pipeline = transforms.Compose([
70
  transforms.Resize((224, 224)),
71
  transforms.ToTensor(),
72
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
73
  ])
74
  logger.info(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")
75
 
76
  # --- Funkcja do predykcji z kwantyfikacj膮 niepewno艣ci (MC Dropout) ---
77
# --- Prediction with uncertainty quantification (MC Dropout) ---
def predict_with_mc_dropout(current_model_instance, batch_tensor_on_device, mc_dropout_samples: int, uncertainty_threshold_std: float):
    """Run MC Dropout inference on a batch of preprocessed image tensors.

    Performs `mc_dropout_samples` stochastic forward passes with dropout
    active, averages the per-image softmax outputs, and flags a prediction
    as uncertain when the mean per-class standard deviation exceeds
    `uncertainty_threshold_std`.

    Args:
        current_model_instance: torch.nn.Module containing Dropout layer(s).
        batch_tensor_on_device: float tensor of shape (batch, C, H, W),
            already on the model's device.
        mc_dropout_samples: number of stochastic forward passes.
        uncertainty_threshold_std: threshold on the mean per-class std.

    Returns:
        A list with one result dict per image; confident results also carry
        "predicted_class_index" (0-based) for downstream Grad-CAM targeting.
    """
    log = logging.getLogger(__name__)
    batch_size = batch_tensor_on_device.shape[0]
    log.info(f"Performing MC Dropout on a batch of size {batch_size} with {mc_dropout_samples} samples.")

    # BUGFIX: the previous version called model.train(), which also switches
    # BatchNorm layers (present in ResNet-18) to batch statistics and corrupts
    # inference. Enable stochastic mode ONLY for Dropout modules.
    original_mode_is_training = current_model_instance.training
    for module in current_model_instance.modules():
        if isinstance(module, nn.Dropout):
            module.train()

    # Collect softmax outputs; num_classes is inferred from the model output
    # instead of being hardcoded to 5.
    sampled_probs = []
    with torch.no_grad():
        for _ in range(mc_dropout_samples):
            output = current_model_instance(batch_tensor_on_device)
            sampled_probs.append(torch.nn.functional.softmax(output, dim=1).cpu().numpy())

    # Restore the original module modes (model.eval() also resets the
    # Dropout layers we switched above).
    if not original_mode_is_training:
        current_model_instance.eval()

    all_probs_batch = np.stack(sampled_probs, axis=1)  # (batch, samples, classes)
    mean_probabilities_batch = np.mean(all_probs_batch, axis=1)
    std_dev_probabilities_batch = np.std(all_probs_batch, axis=1)

    results = []
    for i in range(batch_size):
        mean_probabilities = mean_probabilities_batch[i]
        std_dev_probabilities = std_dev_probabilities_batch[i]

        # Explicit int()/float() so plain Python types (not numpy scalars)
        # reach the result dict and, later, ClassifierOutputTarget.
        predicted_class_index = int(np.argmax(mean_probabilities))
        # NOTE(review): this is the MAX probability of the predicted class
        # across MC samples (optimistic), not the mean — kept as-is to
        # preserve the existing "Confidence (MaxProb)" semantics; confirm
        # this is intended.
        confidence_in_predicted_class = float(np.max(all_probs_batch[i, :, predicted_class_index]))
        uncertainty_metric = float(np.mean(std_dev_probabilities))
        is_uncertain = uncertainty_metric > uncertainty_threshold_std

        log.info(f"MC Dropout Results for image {i}: Predicted Index: {predicted_class_index}, Confidence (MaxProb): {confidence_in_predicted_class:.4f}, Uncertainty (avg_std): {uncertainty_metric:.4f}, Is Uncertain: {is_uncertain}")

        birads_category_if_confident = predicted_class_index + 1
        class_probs = {str(j + 1): float(p) for j, p in enumerate(mean_probabilities)}

        if is_uncertain:
            result = {
                "birads": None, "confidence": None,
                "interpretation": f"Model jest niepewny co do tego obrazu (niepewno艣膰: {uncertainty_metric:.4f}). Sprawd藕 jako艣膰 i typ obrazu.",
                "class_probabilities": class_probs,
                "grad_cam_image_base64": None, "error": "High prediction uncertainty",
                "details": f"Uncertainty metric ({uncertainty_metric:.4f}) przekroczy艂a pr贸g ({uncertainty_threshold_std})."
            }
        else:
            result = {
                "birads": birads_category_if_confident,
                "confidence": confidence_in_predicted_class,
                "interpretation": interpretations_dict.get(birads_category_if_confident, "Nieznana klasyfikacja"),
                "class_probabilities": class_probs,
                "grad_cam_image_base64": None, "error": None,
                "details": f"Uncertainty metric ({uncertainty_metric:.4f}) jest w granicach progu ({uncertainty_threshold_std}).",
                "predicted_class_index": predicted_class_index
            }
        results.append(result)

    return results
135
+
136
+ # --- Funkcja do tworzenia obrazu z na艂o偶on膮 map膮 Grad-CAM ---
137
def create_grad_cam_overlay_image(original_pil_image: Image.Image, grayscale_cam: np.ndarray, birads_category: int, transparency: float = 0.5, activation_threshold: float = 0.7, opening_kernel_size: int = 5) -> Optional[Image.Image]:
    """Create a PIL image with the Grad-CAM heatmap blended over the original.

    Args:
        original_pil_image: source image (converted to RGB internally).
        grayscale_cam: 2-D CAM activation map; resized to the image size.
        birads_category: BI-RADS class 1-5, selects the highlight colour.
        transparency: maximum opacity of the overlay (0..1).
        activation_threshold: activations below this (after min-max
            normalization) are zeroed; was previously hardcoded to 0.7.
        opening_kernel_size: side of the square kernel used by the
            morphological opening that removes speckles; previously 5.

    Returns:
        The blended PIL image, or None if anything fails (the error is
        logged; callers check the result's truthiness).
    """
    try:
        img_np = np.array(original_pil_image.convert('RGB')).astype(np.float32) / 255.0
        # cv2.resize takes (width, height); image arrays are (height, width, 3).
        cam_resized = cv2.resize(grayscale_cam, (img_np.shape[1], img_np.shape[0]))
        # Min-max normalize; the epsilon guards against a flat (constant) CAM.
        cam_normalized = (cam_resized - np.min(cam_resized)) / (np.max(cam_resized) - np.min(cam_resized) + 1e-8)
        # Keep only strong activations, then clean small speckles.
        cam_normalized[cam_normalized < activation_threshold] = 0
        kernel = np.ones((opening_kernel_size, opening_kernel_size), np.uint8)
        cam_cleaned = cv2.morphologyEx(cam_normalized, cv2.MORPH_OPEN, kernel)
        # One highlight colour per BI-RADS category (RGB in [0, 1]); grey fallback.
        birads_colors_rgb = {
            1: (0.1, 0.7, 0.1), 2: (0.53, 0.81, 0.92), 3: (1.0, 0.9, 0.0),
            4: (1.0, 0.5, 0.0), 5: (0.9, 0.1, 0.1)
        }
        chosen_color = np.array(birads_colors_rgb.get(birads_category, (0.5, 0.5, 0.5)))
        # Broadcast the colour over the whole frame (replaces the per-channel loop).
        color_overlay_np = np.zeros_like(img_np)
        color_overlay_np[:, :] = chosen_color
        # Alpha blend: CAM intensity scaled by `transparency` sets per-pixel opacity.
        alpha_expanded = (cam_cleaned * transparency)[..., np.newaxis]
        highlighted_image_np = img_np * (1 - alpha_expanded) + color_overlay_np * alpha_expanded
        final_image_np = (np.clip(highlighted_image_np, 0, 1) * 255).astype(np.uint8)
        return Image.fromarray(final_image_np)
    except Exception as e:
        logger.error(f"B艂膮d podczas tworzenia obrazu Grad-CAM overlay: {e}", exc_info=True)
        return None
162
 
163
+ # --- ZAKTUALIZOWANA Funkcja do heurystycznych test贸w OOD ---
164
# --- Heuristic out-of-distribution (OOD) pre-filter ---
def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str, colorfulness_threshold: float, uniformity_threshold: float, aspect_ratio_min: float, aspect_ratio_max: float) -> Optional[str]:
    """
    Run cheap OOD heuristics on an uploaded image.

    Returns a prefixed error message ("INVALID_IMAGE_TYPE: ..." or
    "HEURISTIC_FAILED: ...") when a check rejects the image, or None when
    all checks pass.
    """
    logger.info(f"[RequestID: {request_id}] Uruchamianie heurystycznych test贸w OOD...")
    width, height = pil_image.size

    # Check colourfulness first — it is the most common rejection reason.
    # For a greyscale image converted to RGB, R == G == B, so the per-pixel
    # cross-channel std is ~0; a genuinely colourful image scores high.
    rgb_pixels = np.array(pil_image.convert('RGB'))
    mean_std_across_channels = np.mean(np.std(rgb_pixels, axis=2))
    logger.info(f"[RequestID: {request_id}] Heurystyka: Kolorowo艣膰 = {mean_std_across_channels:.2f} (pr贸g: {colorfulness_threshold})")
    if mean_std_across_channels > colorfulness_threshold:
        msg = f"Wykryto kolorowy obraz (wska藕nik: {mean_std_across_channels:.2f}). System oczekuje obrazu w skali szaro艣ci, typowego dla bada艅 medycznych."
        logger.warning(f"[RequestID: {request_id}] Heurystyka OOD ODRZUCONA: {msg}")
        return f"INVALID_IMAGE_TYPE: {msg}"

    # Reject images whose width/height ratio falls outside the allowed band.
    aspect_ratio = width / height
    within_band = aspect_ratio_min < aspect_ratio < aspect_ratio_max
    if not within_band:
        msg = f"Nietypowe proporcje obrazu: {aspect_ratio:.2f}."
        return f"HEURISTIC_FAILED: {msg}"

    # Reject near-uniform images (e.g. an all-black frame): low intensity spread.
    std_dev_intensity = np.std(np.array(pil_image.convert('L')))
    if std_dev_intensity < uniformity_threshold:
        msg = f"Obraz wydaje si臋 zbyt jednolity (np. ca艂y czarny): {std_dev_intensity:.2f}."
        return f"HEURISTIC_FAILED: {msg}"

    logger.info(f"[RequestID: {request_id}] Heurystyczne testy OOD zako艅czone pomy艣lnie.")
    return None
197
 
198
  # --- Aplikacja FastAPI ---
 
199
  class PredictionResult(BaseModel):
200
  birads: Optional[int] = None
201
  confidence: Optional[float] = None
 
205
  error: Optional[str] = None
206
  details: Optional[str] = None
207
 
 
208
  app = FastAPI(title="BI-RADS Mammography Classification API")
209
 
210
@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook: download the model weights and build the
    classifier (via initialize_model) before the first request is served."""
    logger.info("Rozpoczynanie eventu startup aplikacji FastAPI.")
    initialize_model()
214
 
215
+ # --- ZAKTUALIZOWANY Endpoint /predict/ ---
216
+ @app.post("/predict/", response_model=List[PredictionResult])
217
+ async def predict_images(
218
+ files: List[UploadFile] = File(...),
219
  colorfulness_threshold: float = 15.0,
220
  uniformity_threshold: float = 10.0,
221
  aspect_ratio_min: float = 0.4,
 
223
  mc_dropout_samples: int = 25,
224
  uncertainty_threshold_std: float = 0.08
225
  ):
226
+ request_id = os.urandom(8).hex()
227
+ logger.info(f"[RequestID: {request_id}] Otrzymano 偶膮danie /predict/ dla {len(files)} plik贸w.")
 
 
 
 
228
 
229
  if model_instance is None or transform_pipeline is None:
230
+ raise HTTPException(status_code=503, detail="Model nie jest zainicjalizowany.")
231
+
232
+ all_results = []
233
+ valid_images_pil = []
234
+ valid_tensors = []
235
+ original_indices = []
236
+
237
+ for idx, file in enumerate(files):
238
+ try:
239
+ contents = await file.read()
240
+ image_pil_original = Image.open(io.BytesIO(contents))
241
+
242
+ ood_error_details = run_heuristic_ood_checks(
243
+ image_pil_original.copy(), request_id,
244
+ colorfulness_threshold, uniformity_threshold, aspect_ratio_min, aspect_ratio_max
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  )
246
+
247
+ if ood_error_details:
248
+ # Rozpoznajemy nasz specjalny typ b艂臋du
249
+ if ood_error_details.startswith("INVALID_IMAGE_TYPE"):
250
+ error_type = "Invalid Image Type"
251
+ interpretation = "Przes艂any plik nie wygl膮da na obraz mammograficzny. Prosz臋 wgra膰 odpowiednie zdj臋cie rentgenowskie."
252
+ details = ood_error_details.replace("INVALID_IMAGE_TYPE: ", "")
253
+ else: # Pozosta艂e b艂臋dy heurystyczne
254
+ error_type = "Heuristic OOD check failed"
255
+ interpretation = "Obraz odrzucony przez wst臋pne testy. Mo偶e mie膰 nietypowe wymiary lub by膰 zbyt jednolity."
256
+ details = ood_error_details.replace("HEURISTIC_FAILED: ", "")
257
+
258
+ result = PredictionResult(
259
+ interpretation=interpretation,
260
+ class_probabilities={}, error=error_type,
261
+ details=details
262
+ )
263
+ all_results.append((idx, result))
264
+ continue
265
+
266
+ image_rgb = image_pil_original.convert("RGB")
267
+ input_tensor = transform_pipeline(image_rgb).unsqueeze(0).to(DEVICE)
268
+
269
+ valid_images_pil.append(image_rgb)
270
+ valid_tensors.append(input_tensor)
271
+ original_indices.append(idx)
272
+
273
+ except Exception as e:
274
+ logger.error(f"[RequestID: {request_id}] B艂膮d podczas odczytu pliku {file.filename}: {e}", exc_info=True)
275
+ result = PredictionResult(
276
+ interpretation="B艂膮d podczas przetwarzania pliku.", class_probabilities={},
277
+ error="File processing error.", details=str(e)
278
+ )
279
+ all_results.append((idx, result))
280
+
281
+ if valid_tensors:
282
+ batch_tensor = torch.cat(valid_tensors, dim=0)
283
+ logger.info(f"[RequestID: {request_id}] Przetwarzanie wsadu {batch_tensor.shape[0]} poprawnych obraz贸w.")
284
+
285
+ mc_results = predict_with_mc_dropout(model_instance, batch_tensor, mc_dropout_samples, uncertainty_threshold_std)
286
+
287
+ model_instance.eval()
288
  target_layers = [model_instance.layer4[-1]]
289
  cam_algorithm = GradCAMPlusPlus(model=model_instance, target_layers=target_layers)
290
 
291
+ for i, result_dict in enumerate(mc_results):
292
+ if not result_dict.get("error"):
293
+ birads_cat = result_dict["birads"]
294
+ pred_idx = result_dict["predicted_class_index"]
295
+
296
+ input_tensor_for_cam = batch_tensor[i].unsqueeze(0).clone().detach().requires_grad_(True)
297
+ targets_for_cam = [ClassifierOutputTarget(pred_idx)]
298
+
299
+ grayscale_cam = cam_algorithm(input_tensor=input_tensor_for_cam, targets=targets_for_cam)
300
+
301
+ if grayscale_cam is not None:
302
+ overlay_image_pil = create_grad_cam_overlay_image(
303
+ original_pil_image=valid_images_pil[i],
304
+ grayscale_cam=grayscale_cam[0, :],
305
+ birads_category=birads_cat
306
+ )
307
+ if overlay_image_pil:
308
+ buffered = io.BytesIO()
309
+ overlay_image_pil.save(buffered, format="PNG")
310
+ result_dict["grad_cam_image_base64"] = base64.b64encode(buffered.getvalue()).decode('utf-8')
311
+
312
+ result_dict.pop("predicted_class_index", None)
313
+ all_results.append((original_indices[i], PredictionResult(**result_dict)))
314
+
315
+ all_results.sort(key=lambda x: x[0])
316
+ final_results = [res for _, res in all_results]
317
+
318
+ return final_results
 
 
319
 
320
@app.get("/")
async def root():
    """Root endpoint: log the hit and return a static welcome payload."""
    logger.info("Otrzymano 偶膮danie GET na /")
    welcome_text = "Witaj w BI-RADS Classification API! U偶yj endpointu /predict/ do wysy艂ania obraz贸w."
    return {"message": welcome_text}