Enterwar99 commited on
Commit
768cd43
·
verified ·
1 Parent(s): fe7534d

Update api_app.py

Browse files
Files changed (1) hide show
  1. api_app.py +39 -28
api_app.py CHANGED
@@ -31,14 +31,7 @@ MODEL_FILENAME = "best_model.pth" # Nazwa pliku modelu w repozytorium
31
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
33
  IMAGENET_STD = [0.229, 0.224, 0.225]
34
-
35
- # --- Konfigurowalne progi (ze zmiennych środowiskowych z wartościami domyślnymi) ---
36
- COLORFULNESS_THRESHOLD = float(os.environ.get("COLORFULNESS_THRESHOLD", 15))
37
- UNIFORMITY_THRESHOLD = float(os.environ.get("UNIFORMITY_THRESHOLD", 10))
38
- ASPECT_RATIO_MIN = float(os.environ.get("ASPECT_RATIO_MIN", 0.4))
39
- ASPECT_RATIO_MAX = float(os.environ.get("ASPECT_RATIO_MAX", 2.5))
40
- MC_DROPOUT_SAMPLES = int(os.environ.get("MC_DROPOUT_SAMPLES", 25))
41
- UNCERTAINTY_THRESHOLD_STD = float(os.environ.get("UNCERTAINTY_THRESHOLD_STD", 0.08))
42
 
43
  # Globalne zmienne dla modelu i transformacji
44
  model_instance = None
@@ -98,11 +91,11 @@ def initialize_model():
98
  logger.info(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")
99
 
100
  # --- Funkcja do predykcji z kwantyfikacją niepewności (MC Dropout) ---
101
- def predict_with_mc_dropout(current_model_instance, input_tensor_on_device): # Usunięto domyślne wartości, będą brane z globalnych zmiennych
102
  """
103
  Wykonuje predykcję z użyciem Monte Carlo Dropout do oszacowania niepewności.
104
  """
105
- logger.info(f"Performing MC Dropout with {MC_DROPOUT_SAMPLES} samples. Uncertainty threshold (std): {UNCERTAINTY_THRESHOLD_STD}")
106
 
107
  original_mode_is_training = current_model_instance.training
108
  current_model_instance.train() # Włącz warstwy dropout
@@ -110,7 +103,7 @@ def predict_with_mc_dropout(current_model_instance, input_tensor_on_device): # U
110
  all_probs_list = []
111
  with torch.no_grad(): # Gradienty nie są potrzebne do samego przejścia w przód
112
  for _ in range(n_samples):
113
- output = current_model_instance(input_tensor_on_device) # Użyj MC_DROPOUT_SAMPLES
114
  probs_tensor = torch.nn.functional.softmax(output, dim=1)
115
  all_probs_list.append(probs_tensor.cpu().numpy())
116
 
@@ -128,7 +121,7 @@ def predict_with_mc_dropout(current_model_instance, input_tensor_on_device): # U
128
  # Użyj średniej odchyleń standardowych prawdopodobieństw wszystkich klas jako metryki niepewności
129
  uncertainty_metric = np.mean(std_dev_probabilities)
130
 
131
- is_uncertain = uncertainty_metric > UNCERTAINTY_THRESHOLD_STD # Użyj UNCERTAINTY_THRESHOLD_STD
132
  logger.info(f"MC Dropout Results: Predicted Index: {predicted_class_index}, Confidence: {confidence_in_predicted_class:.4f}, Uncertainty (avg_std): {uncertainty_metric:.4f}, Is Uncertain: {is_uncertain}")
133
 
134
  birads_category_if_confident = predicted_class_index + 1
@@ -140,7 +133,7 @@ def predict_with_mc_dropout(current_model_instance, input_tensor_on_device): # U
140
  "class_probabilities": {str(j + 1): float(mean_probabilities[j]) for j in range(len(mean_probabilities))},
141
  "grad_cam_image_base64": None, # Zostanie wypełnione później, jeśli pewne
142
  "error": "High prediction uncertainty" if is_uncertain else None,
143
- "details": f"Uncertainty metric ({uncertainty_metric:.4f}) {'przekroczyła' if is_uncertain else 'jest w granicach'} progu ({UNCERTAINTY_THRESHOLD_STD})."
144
  }
145
  return base_result, predicted_class_index # Zwróć również indeks dla Grad-CAM
146
 
@@ -187,10 +180,9 @@ def create_grad_cam_overlay_image(original_pil_image: Image.Image, grayscale_cam
187
  return None
188
 
189
  # --- Funkcja do heurystycznych testów OOD ---
190
- def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str) -> bool:
191
  """
192
  Wykonuje zestaw prostych heurystyk do wykrywania obrazów spoza dystrybucji.
193
- Zwraca True, jeśli obraz przeszedł testy (prawdopodobnie jest OK), False jeśli nie.
194
  """
195
  logger.info(f"[RequestID: {request_id}] Uruchamianie heurystycznych testów OOD...")
196
  width, height = pil_image.size
@@ -198,8 +190,8 @@ def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str) -> bool:
198
  # Heurystyka 1: Sprawdzenie proporcji obrazu
199
  aspect_ratio = width / height
200
  if not (ASPECT_RATIO_MIN < aspect_ratio < ASPECT_RATIO_MAX): # Użyj zmiennych konfiguracyjnych
201
- logger.warning(f"[RequestID: {request_id}] Heurystyka OOD: Nietypowe proporcje obrazu: {aspect_ratio:.2f}. Odrzucam.")
202
- return False
203
 
204
  # Heurystyka 2: Analiza "kolorowości" (dla obrazów RGB)
205
  # Mammografie są w skali szarości; jeśli obraz jest kolorowy, ta metryka będzie wysoka.
@@ -210,7 +202,7 @@ def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str) -> bool:
210
  img_np_rgb = np.array(img_rgb_for_color_check)
211
  std_per_pixel_across_channels = np.std(img_np_rgb, axis=2) # Odch. std. dla każdego piksela po kanałach R,G,B
212
  mean_std_across_channels = np.mean(std_per_pixel_across_channels)
213
- if mean_std_across_channels > COLORFULNESS_THRESHOLD: # Użyj zmiennej konfiguracyjnej
214
  logger.warning(f"[RequestID: {request_id}] Heurystyka OOD: Obraz wydaje się zbyt kolorowy. Średnie odch. std. między kanałami: {mean_std_across_channels:.2f}. Odrzucam.")
215
  return False
216
 
@@ -218,7 +210,7 @@ def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str) -> bool:
218
  # Konwertujemy do skali szarości dla tej analizy
219
  gray_image = pil_image.convert('L')
220
  std_dev_intensity = np.std(np.array(gray_image))
221
- if std_dev_intensity < UNIFORMITY_THRESHOLD: # Użyj zmiennej konfiguracyjnej
222
  logger.warning(f"[RequestID: {request_id}] Heurystyka OOD: Obraz wydaje się zbyt jednolity (mało zróżnicowania jasności). Odch. std. intensywności: {std_dev_intensity:.2f}. Odrzucam.")
223
  return False
224
 
@@ -246,14 +238,22 @@ async def startup_event():
246
  initialize_model()
247
 
248
  @app.post("/predict/", response_model=List[PredictionResult]) # Użycie modelu Pydantic
249
- async def predict_image(file: UploadFile = File(...)):
 
 
 
 
 
 
 
 
 
250
  """
251
  Endpoint do klasyfikacji obrazu mammograficznego.
252
- Oczekuje pliku obrazu (JPG, PNG).
253
- Zwraca listę z wynikami (nawet jeśli tylko jeden obraz).
254
  """
255
  request_id = os.urandom(8).hex() # Prosty identyfikator żądania
256
  logger.info(f"[RequestID: {request_id}] Otrzymano żądanie /predict/")
 
257
 
258
  if model_instance is None or transform_pipeline is None:
259
  logger.error(f"[RequestID: {request_id}] Model nie jest zainicjalizowany podczas żądania /predict/.")
@@ -269,14 +269,25 @@ async def predict_image(file: UploadFile = File(...)):
269
  # --- Etap: Heurystyczne testy OOD (PRE-FILTR) ---
270
  # Używamy .copy(), aby uniknąć potencjalnych modyfikacji oryginalnego obiektu image_pil_original
271
  # przez funkcję run_heuristic_ood_checks, jeśli by takie wykonywała (np. konwersje inplace).
272
- if not run_heuristic_ood_checks(image_pil_original.copy(), request_id):
 
 
 
 
273
  # Heurystyki wykryły problem
274
  return JSONResponse(
275
  status_code=400, # Bad Request
276
- content=[{ # API oczekuje listy wyników
277
- "error": "Image does not appear to be a valid medical mammogram based on initial checks.",
278
- "details": "Heurystyczne testy OOD nie powiodły się. Obraz odrzucony przed analizą przez model AI."
279
- }]
 
 
 
 
 
 
 
280
  )
281
  image = image_pil_original.convert("RGB") # Teraz konwertuj do RGB dla modelu głównego
282
  except Exception as e:
@@ -293,7 +304,7 @@ async def predict_image(file: UploadFile = File(...)):
293
  mc_output_dict, predicted_idx_from_mc = predict_with_mc_dropout(
294
  model_instance,
295
  input_tensor,
296
- # n_samples i uncertainty_threshold_std teraz brane z globalnych zmiennych
297
  )
298
 
299
  results = []
 
31
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
33
  IMAGENET_STD = [0.229, 0.224, 0.225]
34
+ # Progi będą teraz przekazywane jako parametry zapytania
 
 
 
 
 
 
 
35
 
36
  # Globalne zmienne dla modelu i transformacji
37
  model_instance = None
 
91
  logger.info(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")
92
 
93
  # --- Funkcja do predykcji z kwantyfikacją niepewności (MC Dropout) ---
94
+ def predict_with_mc_dropout(current_model_instance, input_tensor_on_device, mc_dropout_samples: int, uncertainty_threshold_std: float): # Dodano parametry
95
  """
96
  Wykonuje predykcję z użyciem Monte Carlo Dropout do oszacowania niepewności.
97
  """
98
+ logger.info(f"Performing MC Dropout with {mc_dropout_samples} samples. Uncertainty threshold (std): {uncertainty_threshold_std}") # Użyj parametrów
99
 
100
  original_mode_is_training = current_model_instance.training
101
  current_model_instance.train() # Włącz warstwy dropout
 
103
  all_probs_list = []
104
  with torch.no_grad(): # Gradienty nie są potrzebne do samego przejścia w przód
105
  for _ in range(n_samples):
106
+ output = current_model_instance(input_tensor_on_device) # Użyj mc_dropout_samples
107
  probs_tensor = torch.nn.functional.softmax(output, dim=1)
108
  all_probs_list.append(probs_tensor.cpu().numpy())
109
 
 
121
  # Użyj średniej odchyleń standardowych prawdopodobieństw wszystkich klas jako metryki niepewności
122
  uncertainty_metric = np.mean(std_dev_probabilities)
123
 
124
+ is_uncertain = uncertainty_metric > uncertainty_threshold_std # Użyj uncertainty_threshold_std
125
  logger.info(f"MC Dropout Results: Predicted Index: {predicted_class_index}, Confidence: {confidence_in_predicted_class:.4f}, Uncertainty (avg_std): {uncertainty_metric:.4f}, Is Uncertain: {is_uncertain}")
126
 
127
  birads_category_if_confident = predicted_class_index + 1
 
133
  "class_probabilities": {str(j + 1): float(mean_probabilities[j]) for j in range(len(mean_probabilities))},
134
  "grad_cam_image_base64": None, # Zostanie wypełnione później, jeśli pewne
135
  "error": "High prediction uncertainty" if is_uncertain else None,
136
+ "details": f"Uncertainty metric ({uncertainty_metric:.4f}) {'przekroczyła' if is_uncertain else 'jest w granicach'} progu ({uncertainty_threshold_std})." # Użyj uncertainty_threshold_std
137
  }
138
  return base_result, predicted_class_index # Zwróć również indeks dla Grad-CAM
139
 
 
180
  return None
181
 
182
  # --- Funkcja do heurystycznych testów OOD ---
183
+ def run_heuristic_ood_checks(pil_image: Image.Image, request_id: str, colorfulness_threshold: float, uniformity_threshold: float, aspect_ratio_min: float, aspect_ratio_max: float) -> bool: # Dodano parametry
184
  """
185
  Wykonuje zestaw prostych heurystyk do wykrywania obrazów spoza dystrybucji.
 
186
  """
187
  logger.info(f"[RequestID: {request_id}] Uruchamianie heurystycznych testów OOD...")
188
  width, height = pil_image.size
 
190
  # Heurystyka 1: Sprawdzenie proporcji obrazu
191
  aspect_ratio = width / height
192
  if not (ASPECT_RATIO_MIN < aspect_ratio < ASPECT_RATIO_MAX): # Użyj zmiennych konfiguracyjnych
193
+ if not (aspect_ratio_min < aspect_ratio < aspect_ratio_max): # Użyj parametrów
194
+ return False
195
 
196
  # Heurystyka 2: Analiza "kolorowości" (dla obrazów RGB)
197
  # Mammografie są w skali szarości; jeśli obraz jest kolorowy, ta metryka będzie wysoka.
 
202
  img_np_rgb = np.array(img_rgb_for_color_check)
203
  std_per_pixel_across_channels = np.std(img_np_rgb, axis=2) # Odch. std. dla każdego piksela po kanałach R,G,B
204
  mean_std_across_channels = np.mean(std_per_pixel_across_channels)
205
+ if mean_std_across_channels > colorfulness_threshold: # Użyj parametru
206
  logger.warning(f"[RequestID: {request_id}] Heurystyka OOD: Obraz wydaje się zbyt kolorowy. Średnie odch. std. między kanałami: {mean_std_across_channels:.2f}. Odrzucam.")
207
  return False
208
 
 
210
  # Konwertujemy do skali szarości dla tej analizy
211
  gray_image = pil_image.convert('L')
212
  std_dev_intensity = np.std(np.array(gray_image))
213
+ if std_dev_intensity < uniformity_threshold: # Użyj parametru
214
  logger.warning(f"[RequestID: {request_id}] Heurystyka OOD: Obraz wydaje się zbyt jednolity (mało zróżnicowania jasności). Odch. std. intensywności: {std_dev_intensity:.2f}. Odrzucam.")
215
  return False
216
 
 
238
  initialize_model()
239
 
240
  @app.post("/predict/", response_model=List[PredictionResult]) # Użycie modelu Pydantic
241
+ # Dodano parametry zapytania z wartościami domyślnymi
242
+ async def predict_image(
243
+ file: UploadFile = File(...),
244
+ colorfulness_threshold: float = 15.0,
245
+ uniformity_threshold: float = 10.0,
246
+ aspect_ratio_min: float = 0.4,
247
+ aspect_ratio_max: float = 2.5,
248
+ mc_dropout_samples: int = 25,
249
+ uncertainty_threshold_std: float = 0.08
250
+ ):
251
  """
252
  Endpoint do klasyfikacji obrazu mammograficznego.
 
 
253
  """
254
  request_id = os.urandom(8).hex() # Prosty identyfikator żądania
255
  logger.info(f"[RequestID: {request_id}] Otrzymano żądanie /predict/")
256
+ logger.info(f"[RequestID: {request_id}] Received parameters: colorfulness_threshold={colorfulness_threshold}, uniformity_threshold={uniformity_threshold}, aspect_ratio_min={aspect_ratio_min}, aspect_ratio_max={aspect_ratio_max}, mc_dropout_samples={mc_dropout_samples}, uncertainty_threshold_std={uncertainty_threshold_std}")
257
 
258
  if model_instance is None or transform_pipeline is None:
259
  logger.error(f"[RequestID: {request_id}] Model nie jest zainicjalizowany podczas żądania /predict/.")
 
269
  # --- Etap: Heurystyczne testy OOD (PRE-FILTR) ---
270
  # Używamy .copy(), aby uniknąć potencjalnych modyfikacji oryginalnego obiektu image_pil_original
271
  # przez funkcję run_heuristic_ood_checks, jeśli by takie wykonywała (np. konwersje inplace).
272
+ # Przekazujemy parametry do funkcji
273
+ if not run_heuristic_ood_checks(
274
+ image_pil_original.copy(), request_id,
275
+ colorfulness_threshold, uniformity_threshold, aspect_ratio_min, aspect_ratio_max
276
+ ):
277
  # Heurystyki wykryły problem
278
  return JSONResponse(
279
  status_code=400, # Bad Request
280
+ content=[
281
+ {
282
+ "interpretation": "Obraz odrzucony przez wstępne testy heurystyczne. Nie wygląda na poprawny obraz medyczny.",
283
+ "class_probabilities": {}, # Puste, bo nie było predykcji klas
284
+ "error": "Image does not appear to be a valid medical mammogram based on initial checks.",
285
+ "details": "Heurystyczne testy OOD nie powiodły się. Obraz odrzucony przed analizą przez model AI.",
286
+ "birads": None,
287
+ "confidence": None,
288
+ "grad_cam_image_base64": None
289
+ }
290
+ ]
291
  )
292
  image = image_pil_original.convert("RGB") # Teraz konwertuj do RGB dla modelu głównego
293
  except Exception as e:
 
304
  mc_output_dict, predicted_idx_from_mc = predict_with_mc_dropout(
305
  model_instance,
306
  input_tensor,
307
+ mc_dropout_samples, uncertainty_threshold_std # Przekazujemy parametry
308
  )
309
 
310
  results = []