Eraly-ml commited on
Commit
71667c6
·
verified ·
1 Parent(s): 5abc557

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -27
app.py CHANGED
@@ -5,15 +5,12 @@ from PIL import Image
5
  import os
6
  from typing import Tuple, List, Dict
7
 
8
- # Единоразовое определение устройства
9
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
-
11
  def load_model() -> Tuple[torch.nn.Module, List[str]]:
12
  """
13
- Загружает модель и список меток классов.
14
  Returns:
15
- model: Загруженная модель PyTorch.
16
- labels: Список меток классов.
17
  """
18
  model_path = "skinconvnext_scripted.pt"
19
  labels_path = "labels.txt"
@@ -23,18 +20,18 @@ def load_model() -> Tuple[torch.nn.Module, List[str]]:
23
  if not os.path.exists(labels_path):
24
  raise FileNotFoundError("File labels.txt not found.")
25
 
 
26
  model = torch.jit.load(model_path, map_location=device)
27
  model.eval()
28
- model.to(device) # Перемещаем модель на устройство сразу
29
 
30
  with open(labels_path, "r") as f:
31
- labels = f.read().splitlines()
32
 
33
  return model, labels
34
 
35
  model, labels = load_model()
36
 
37
- # Определение преобразований изображения (создаётся один раз)
38
  preprocess = transforms.Compose([
39
  transforms.Resize((224, 224)),
40
  transforms.ToTensor(),
@@ -44,48 +41,48 @@ preprocess = transforms.Compose([
44
 
45
  def predict(image: Image.Image) -> Dict[str, float]:
46
  """
47
- Выполняет предсказание для переданного изображения.
48
  Args:
49
- image (PIL.Image): Входное изображение.
50
 
51
  Returns:
52
- Dict[str, float]: Словарь, где ключи имена классов, а значения вероятности.
53
  """
54
  try:
55
- # Приводим изображение к RGB и преобразуем
56
  image = image.convert("RGB")
57
- image_tensor = preprocess(image).unsqueeze(0).to(device)
 
 
 
58
 
59
  with torch.no_grad():
60
  output = model(image_tensor)
61
- # Предполагаем, что output имеет размерность [1, N]
62
  scores = torch.nn.functional.softmax(output[0], dim=0)
63
 
64
- # Формируем словарь с предсказаниями
65
  predictions = {label: float(score) for label, score in zip(labels, scores)}
66
  sorted_predictions = dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
 
67
  return sorted_predictions
68
  except Exception as e:
69
  return {"error": str(e)}
70
 
71
- # Описание интерфейса Gradio
72
  title = "🔥 Skin-AI"
73
  description = (
74
  "🔬 **Skin-AI — AI-Powered Skin Disease Classification**\n\n"
75
- "Проект использует глубокую модель для классификации заболеваний кожи по изображению.\n\n"
76
- "### 🚀 Как использовать:\n\n"
77
- "1️⃣ Загрузите или сфотографируйте поражённый участок кожи.\n\n"
78
- "2️⃣ Нажмите кнопку 'Submit'.\n\n"
79
- "3️⃣ Приложение покажет вероятности для возможных заболеваний.\n\n"
80
- "⚠️ **Внимание!** Результаты предоставлены в ознакомительных целях и не являются медицинской диагностикой.\n\n"
81
- "### 🛠 Используемые технологии:\n"
82
- "- PyTorch\n"
83
  "- Gradio\n"
84
  "- Hugging Face Spaces\n\n"
85
- "🔗 Исходный код: [Hugging Face](https://huggingface.co/Eraly-ml/Skin-AI)"
86
  )
87
 
88
- # Примеры изображений
89
  examples = [
90
  ["example1.jpg"],
91
  ["example2.jpg"]
 
5
  import os
6
  from typing import Tuple, List, Dict
7
 
 
 
 
8
  def load_model() -> Tuple[torch.nn.Module, List[str]]:
9
  """
10
+ Loads the model and class labels.
11
  Returns:
12
+ model: The loaded PyTorch model.
13
+ labels: List of class labels.
14
  """
15
  model_path = "skinconvnext_scripted.pt"
16
  labels_path = "labels.txt"
 
20
  if not os.path.exists(labels_path):
21
  raise FileNotFoundError("File labels.txt not found.")
22
 
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  model = torch.jit.load(model_path, map_location=device)
25
  model.eval()
 
26
 
27
  with open(labels_path, "r") as f:
28
+ labels = [line.strip() for line in f.readlines()]
29
 
30
  return model, labels
31
 
32
  model, labels = load_model()
33
 
34
+ # Define image preprocessing steps
35
  preprocess = transforms.Compose([
36
  transforms.Resize((224, 224)),
37
  transforms.ToTensor(),
 
41
 
42
  def predict(image: Image.Image) -> Dict[str, float]:
43
  """
44
+ Makes a prediction for the given image.
45
  Args:
46
+ image (PIL.Image): The input image.
47
 
48
  Returns:
49
+ Dict[str, float]: A dictionary where keys are class names, and values are probabilities.
50
  """
51
  try:
 
52
  image = image.convert("RGB")
53
+ image_tensor = preprocess(image).unsqueeze(0)
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+ image_tensor = image_tensor.to(device)
56
+ model.to(device)
57
 
58
  with torch.no_grad():
59
  output = model(image_tensor)
 
60
  scores = torch.nn.functional.softmax(output[0], dim=0)
61
 
 
62
  predictions = {label: float(score) for label, score in zip(labels, scores)}
63
  sorted_predictions = dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
64
+
65
  return sorted_predictions
66
  except Exception as e:
67
  return {"error": str(e)}
68
 
 
69
  title = "🔥 Skin-AI"
70
  description = (
71
  "🔬 **Skin-AI — AI-Powered Skin Disease Classification**\n\n"
72
+ "This project utilizes a deep learning model to classify skin diseases based on an uploaded image.\n\n"
73
+ "### 🚀 How to Use:\n\n"
74
+ "1️⃣ Upload or take a photo of the affected skin area.\n\n"
75
+ "2️⃣ Click the 'Submit' button.\n\n"
76
+ "3️⃣ The app will return the probabilities for possible skin conditions.\n\n"
77
+ "⚠️ **Important!** The results are for informational purposes only and do not constitute a medical diagnosis.\n\n"
78
+ "### 🛠 Technologies Used:\n"
79
+ "- PyTorch (Lightning)\n"
80
  "- Gradio\n"
81
  "- Hugging Face Spaces\n\n"
82
+ "🔗 Source Code: [Hugging Face](https://huggingface.co/Eraly-ml/Skin-AI)"
83
  )
84
 
85
+ # Adding example images
86
  examples = [
87
  ["example1.jpg"],
88
  ["example2.jpg"]