Eraly-ml commited on
Commit
5abc557
·
verified ·
1 Parent(s): 7b665a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -98
app.py CHANGED
@@ -5,132 +5,109 @@ from PIL import Image
5
  import os
6
  from typing import Tuple, List, Dict
7
 
8
- # Устройство
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
- # Загрузка модели
12
  def load_model() -> Tuple[torch.nn.Module, List[str]]:
 
 
 
 
 
 
13
  model_path = "skinconvnext_scripted.pt"
14
  labels_path = "labels.txt"
 
15
  if not os.path.exists(model_path):
16
  raise FileNotFoundError(f"Model not found: {model_path}")
17
  if not os.path.exists(labels_path):
18
- raise FileNotFoundError("labels.txt not found.")
 
19
  model = torch.jit.load(model_path, map_location=device)
20
- model.eval().to(device)
 
 
21
  with open(labels_path, "r") as f:
22
  labels = f.read().splitlines()
 
23
  return model, labels
24
 
25
  model, labels = load_model()
26
 
27
- # Преобразования
28
  preprocess = transforms.Compose([
29
  transforms.Resize((224, 224)),
30
  transforms.ToTensor(),
31
- transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
 
32
  ])
33
 
34
- # Функция предсказаний
35
  def predict(image: Image.Image) -> Dict[str, float]:
 
 
 
 
 
 
 
 
36
  try:
37
- img = image.convert("RGB")
38
- tensor = preprocess(img).unsqueeze(0).to(device)
 
 
39
  with torch.no_grad():
40
- out = model(tensor)
41
- probs = torch.nn.functional.softmax(out[0], dim=0)
42
- preds = {lbl: float(p) for lbl,p in zip(labels, probs)}
43
- return dict(sorted(preds.items(), key=lambda x: x[1], reverse=True))
 
 
 
 
44
  except Exception as e:
45
  return {"error": str(e)}
46
 
47
- # Примеры
48
- examples = [["example1.jpg"], ["example2.jpg"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- # Переводы
51
- translations = {
52
- "Русский": {
53
- "title": "# 🔥 Skin-AI",
54
- "description": (
55
- "🔬 **Skin-AI — AI для классификации кожных заболеваний**\n\n"
56
- "1️⃣ Загрузите или выберите пример ниже.\n"
57
- "2️⃣ Нажмите кнопку «Предсказать».\n"
58
- "⚠️ **Не является медицинской диагностикой.**"
59
- ),
60
- "upload_label": "Загрузить изображение",
61
- "submit": "Предсказать",
62
- "result": "Результат"
63
- },
64
- "Қазақша": {
65
- "title": "# 🔥 Skin-AI",
66
- "description": (
67
- "🔬 **Skin-AI — Терi ауруларын анықтайтын ИИ**\n\n"
68
- "1️⃣ Суретті жүктеңіз не төмендегі мысалдарды таңдаңыз.\n"
69
- "2️⃣ «Болжам жасау» батырмасын басыңыз.\n"
70
- "⚠️ **Бұл медициналық диагноз емес.**"
71
- ),
72
- "upload_label": "Суретті жүктеу",
73
- "submit": "Болжам жасау",
74
- "result": "Нәтиже"
75
- },
76
- "English": {
77
- "title": "# 🔥 Skin-AI",
78
- "description": (
79
- "🔬 **Skin-AI — AI-Powered Skin Disease Classification**\n\n"
80
- "1️⃣ Upload an image or choose an example below.\n"
81
- "2️⃣ Click the “Submit” button.\n"
82
- "⚠️ **This is not a medical diagnosis.**"
83
- ),
84
- "upload_label": "Upload Image",
85
- "submit": "Submit",
86
- "result": "Prediction"
87
- }
88
- }
89
 
90
- # Функция переключения языка
91
- def change_language(lang):
92
- tr = translations[lang]
93
- return (
94
- gr.update(value=tr["title"]),
95
- gr.update(value=tr["description"]),
96
- gr.update(label=tr["upload_label"]),
97
- gr.update(value=tr["submit"]),
98
- gr.update(label=tr["result"])
99
- )
100
 
101
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
102
- # Селектор языка
103
- lang = gr.Dropdown(
104
- choices=["Русский","Қазақша","English"],
105
- value="Русский",
106
- label="🌐 Язык / Language / Тiлдер"
107
- )
108
-
109
- # Заголовок и описание
110
- title_md = gr.Markdown(translations["Русский"]["title"])
111
- desc_md = gr.Markdown(translations["Русский"]["description"])
112
-
113
- # Основные компоненты
114
- image_input = gr.Image(type="pil", label=translations["Русский"]["upload_label"])
115
- submit_button = gr.Button(translations["Русский"]["submit"])
116
- output_label = gr.Label(num_top_classes=3, label=translations["Русский"]["result"])
117
-
118
- # Примеры
119
- gr.Examples(
120
- examples=examples,
121
- inputs=image_input,
122
- outputs=output_label,
123
- fn=predict,
124
- cache_examples=True
125
- )
126
-
127
- # Связи
128
- lang.change(
129
- fn=change_language,
130
- inputs=lang,
131
- outputs=[title_md, desc_md, image_input, submit_button, output_label]
132
- )
133
  submit_button.click(fn=predict, inputs=image_input, outputs=output_label)
134
-
 
 
135
  if __name__ == "__main__":
136
- demo.launch()
 
5
  import os
6
  from typing import Tuple, List, Dict
7
 
8
+ # Единоразовое определение устройства
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
 
11
  def load_model() -> Tuple[torch.nn.Module, List[str]]:
12
+ """
13
+ Загружает модель и список меток классов.
14
+ Returns:
15
+ model: Загруженная модель PyTorch.
16
+ labels: Список меток классов.
17
+ """
18
  model_path = "skinconvnext_scripted.pt"
19
  labels_path = "labels.txt"
20
+
21
  if not os.path.exists(model_path):
22
  raise FileNotFoundError(f"Model not found: {model_path}")
23
  if not os.path.exists(labels_path):
24
+ raise FileNotFoundError("File labels.txt not found.")
25
+
26
  model = torch.jit.load(model_path, map_location=device)
27
+ model.eval()
28
+ model.to(device) # Перемещаем модель на устройство сразу
29
+
30
  with open(labels_path, "r") as f:
31
  labels = f.read().splitlines()
32
+
33
  return model, labels
34
 
35
  model, labels = load_model()
36
 
37
+ # Определение преобразований изображения (создаётся один раз)
38
  preprocess = transforms.Compose([
39
  transforms.Resize((224, 224)),
40
  transforms.ToTensor(),
41
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
42
+ std=[0.229, 0.224, 0.225]),
43
  ])
44
 
 
45
  def predict(image: Image.Image) -> Dict[str, float]:
46
+ """
47
+ Выполняет предсказание для переданного изображения.
48
+ Args:
49
+ image (PIL.Image): Входное изображение.
50
+
51
+ Returns:
52
+ Dict[str, float]: Словарь, где ключи – имена классов, а значения – вероятности.
53
+ """
54
  try:
55
+ # Приводим изображение к RGB и преобразуем
56
+ image = image.convert("RGB")
57
+ image_tensor = preprocess(image).unsqueeze(0).to(device)
58
+
59
  with torch.no_grad():
60
+ output = model(image_tensor)
61
+ # Предполагаем, что output имеет размерность [1, N]
62
+ scores = torch.nn.functional.softmax(output[0], dim=0)
63
+
64
+ # Формируем словарь с предсказаниями
65
+ predictions = {label: float(score) for label, score in zip(labels, scores)}
66
+ sorted_predictions = dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
67
+ return sorted_predictions
68
  except Exception as e:
69
  return {"error": str(e)}
70
 
71
+ # Описание интерфейса Gradio
72
+ title = "🔥 Skin-AI"
73
+ description = (
74
+ "🔬 **Skin-AI — AI-Powered Skin Disease Classification**\n\n"
75
+ "Проект использует глубокую модель для классификации заболеваний кожи по изображению.\n\n"
76
+ "### 🚀 Как использовать:\n\n"
77
+ "1️⃣ Загрузите или сфотографируйте поражённый участок кожи.\n\n"
78
+ "2️⃣ Нажмите кнопку 'Submit'.\n\n"
79
+ "3️⃣ Приложение покажет вероятности для возможных заболеваний.\n\n"
80
+ "⚠️ **Внимание!** Результаты предоставлены в ознакомительных целях и не являются медицинской диагностикой.\n\n"
81
+ "### 🛠 Используемые технологии:\n"
82
+ "- PyTorch\n"
83
+ "- Gradio\n"
84
+ "- Hugging Face Spaces\n\n"
85
+ "🔗 Исходный код: [Hugging Face](https://huggingface.co/Eraly-ml/Skin-AI)"
86
+ )
87
 
88
+ # Примеры изображений
89
+ examples = [
90
+ ["example1.jpg"],
91
+ ["example2.jpg"]
92
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
+ def update_submit_state(image):
95
+ return gr.update(interactive=image is not None)
 
 
 
 
 
 
 
 
96
 
97
+ with gr.Blocks(theme=gr.themes.Soft()) as interface:
98
+ gr.Markdown(title)
99
+ gr.Markdown(description)
100
+
101
+ with gr.Row():
102
+ image_input = gr.Image(type="pil", label="Upload Image")
103
+
104
+ submit_button = gr.Button("Submit", interactive=False)
105
+ output_label = gr.Label(num_top_classes=3, label="Prediction")
106
+
107
+ image_input.change(fn=update_submit_state, inputs=image_input, outputs=submit_button)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  submit_button.click(fn=predict, inputs=image_input, outputs=output_label)
109
+
110
+ gr.Examples(examples, inputs=image_input)
111
+
112
  if __name__ == "__main__":
113
+ interface.launch()