Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,6 +8,7 @@ from transformers import ViTForImageClassification
|
|
| 8 |
from torch import nn
|
| 9 |
from torch.cuda.amp import autocast
|
| 10 |
import os
|
|
|
|
| 11 |
|
| 12 |
# Global configuration
|
| 13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -19,7 +20,7 @@ label_mapping = {
|
|
| 19 |
1: "Меланоцитарный невус",
|
| 20 |
2: "Базальноклеточная карцинома",
|
| 21 |
3: "Актинический кератоз",
|
| 22 |
-
4: "Доброкачественная керато
|
| 23 |
5: "Дерматофиброма",
|
| 24 |
6: "Сосудистые поражения"
|
| 25 |
}
|
|
@@ -88,7 +89,9 @@ class ModelHandler:
|
|
| 88 |
return {"error": "Модели не загружены"}
|
| 89 |
|
| 90 |
inputs = transform_image(image)
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
outputs = self.efficientnet(inputs)
|
| 93 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
| 94 |
|
|
@@ -100,8 +103,9 @@ class ModelHandler:
|
|
| 100 |
return {"error": "Модели не загружены"}
|
| 101 |
|
| 102 |
inputs = transform_image(image)
|
| 103 |
-
|
| 104 |
-
|
|
|
|
| 105 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
| 106 |
|
| 107 |
return self._format_predictions(probs)
|
|
@@ -112,23 +116,22 @@ class ModelHandler:
|
|
| 112 |
return {"error": "Модели не загружены"}
|
| 113 |
|
| 114 |
inputs = transform_image(image)
|
| 115 |
-
|
|
|
|
| 116 |
eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
|
| 117 |
-
deit_probs = torch.nn.functional.softmax(self.deit(inputs).logits, dim=1)
|
| 118 |
ensemble_probs = (eff_probs + deit_probs) / 2
|
| 119 |
|
| 120 |
return self._format_predictions(ensemble_probs)
|
| 121 |
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
return result
|
| 131 |
-
|
| 132 |
|
| 133 |
# Initialize model handler
|
| 134 |
model_handler = ModelHandler()
|
|
@@ -184,4 +187,4 @@ def create_interface():
|
|
| 184 |
if __name__ == "__main__":
|
| 185 |
interface = create_interface()
|
| 186 |
print("🚀 Запуск интерфейса...")
|
| 187 |
-
interface.launch(
|
|
|
|
| 8 |
from torch import nn
|
| 9 |
from torch.cuda.amp import autocast
|
| 10 |
import os
|
| 11 |
+
from contextlib import nullcontext
|
| 12 |
|
| 13 |
# Global configuration
|
| 14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 20 |
1: "Меланоцитарный невус",
|
| 21 |
2: "Базальноклеточная карцинома",
|
| 22 |
3: "Актинический кератоз",
|
| 23 |
+
4: "Доброкачественная кератома",
|
| 24 |
5: "Дерматофиброма",
|
| 25 |
6: "Сосудистые поражения"
|
| 26 |
}
|
|
|
|
| 89 |
return {"error": "Модели не загружены"}
|
| 90 |
|
| 91 |
inputs = transform_image(image)
|
| 92 |
+
# Handle autocast based on device
|
| 93 |
+
ctx = autocast() if device.type == 'cuda' else nullcontext()
|
| 94 |
+
with ctx:
|
| 95 |
outputs = self.efficientnet(inputs)
|
| 96 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
| 97 |
|
|
|
|
| 103 |
return {"error": "Модели не загружены"}
|
| 104 |
|
| 105 |
inputs = transform_image(image)
|
| 106 |
+
ctx = autocast() if device.type == 'cuda' else nullcontext()
|
| 107 |
+
with ctx:
|
| 108 |
+
outputs = self.deit(pixel_values=inputs).logits # Corrected parameter
|
| 109 |
probs = torch.nn.functional.softmax(outputs, dim=1)
|
| 110 |
|
| 111 |
return self._format_predictions(probs)
|
|
|
|
| 116 |
return {"error": "Модели не загружены"}
|
| 117 |
|
| 118 |
inputs = transform_image(image)
|
| 119 |
+
ctx = autocast() if device.type == 'cuda' else nullcontext()
|
| 120 |
+
with ctx:
|
| 121 |
eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
|
| 122 |
+
deit_probs = torch.nn.functional.softmax(self.deit(pixel_values=inputs).logits, dim=1)
|
| 123 |
ensemble_probs = (eff_probs + deit_probs) / 2
|
| 124 |
|
| 125 |
return self._format_predictions(ensemble_probs)
|
| 126 |
|
| 127 |
+
def _format_predictions(self, probs): # Corrected indentation
|
| 128 |
+
top5_probs, top5_indices = torch.topk(probs, 5)
|
| 129 |
+
result = {}
|
| 130 |
+
for i in range(5):
|
| 131 |
+
idx = top5_indices[0][i].item()
|
| 132 |
+
label = label_mapping.get(idx, f"Класс {idx}")
|
| 133 |
+
result[label] = float(top5_probs[0][i].item())
|
| 134 |
+
return result
|
|
|
|
|
|
|
| 135 |
|
| 136 |
# Initialize model handler
|
| 137 |
model_handler = ModelHandler()
|
|
|
|
| 187 |
if __name__ == "__main__":
|
| 188 |
interface = create_interface()
|
| 189 |
print("🚀 Запуск интерфейса...")
|
| 190 |
+
interface.launch(server_port=7860) # Explicitly set port if needed
|