Eraly-ml commited on
Commit
4b90365
·
verified ·
1 Parent(s): 9fa70c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -7,21 +7,20 @@ from typing import Tuple, List, Dict
7
 
8
  def load_model() -> Tuple[torch.nn.Module, List[str]]:
9
  """
10
- Загружает модель и метки классов.
11
 
12
  Returns:
13
- model: Загруженная PyTorch модель.
14
- labels: Список меток классов.
15
  """
16
  model_path = "skin_disease_model_jit.pt"
17
  labels_path = "labels.txt"
18
 
19
  if not os.path.exists(model_path):
20
- raise FileNotFoundError(f"Модель не найдена: {model_path}")
21
  if not os.path.exists(labels_path):
22
- raise FileNotFoundError("Файл labels.txt не найден.")
23
 
24
- # Если доступна GPU, используем её, иначе CPU.
25
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
  model = torch.jit.load(model_path, map_location=device)
27
  model.eval()
@@ -33,7 +32,7 @@ def load_model() -> Tuple[torch.nn.Module, List[str]]:
33
 
34
  model, labels = load_model()
35
 
36
- # Определение преобразований для входного изображения.
37
  preprocess = transforms.Compose([
38
  transforms.Resize((224, 224)),
39
  transforms.ToTensor(),
@@ -43,27 +42,28 @@ preprocess = transforms.Compose([
43
 
44
  def predict(image: Image.Image) -> Dict[str, float]:
45
  """
46
- Выполняет предсказание для переданного изображения.
47
 
48
  Args:
49
- image (PIL.Image): Изображение для анализа.
50
 
51
  Returns:
52
- Dict[str, float]: Словарь, где ключ класс, значение вероятность.
53
  """
54
  try:
55
  image = image.convert("RGB")
56
  image_tensor = preprocess(image).unsqueeze(0)
57
- # Определяем устройство для обработки.
58
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
  image_tensor = image_tensor.to(device)
60
  model.to(device)
 
61
  with torch.no_grad():
62
  output = model(image_tensor)
63
  scores = torch.nn.functional.softmax(output[0], dim=0)
 
64
  predictions = {label: float(score) for label, score in zip(labels, scores)}
65
- # Сортировка предсказаний по убыванию вероятности.
66
  sorted_predictions = dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
 
67
  return sorted_predictions
68
  except Exception as e:
69
  return {"error": str(e)}
@@ -73,7 +73,7 @@ description = (
73
  "🔬 **Skin-AI — AI-Powered Skin Disease Classification**\n\n"
74
  "This project utilizes a deep learning model to classify skin diseases based on an uploaded image.\n\n"
75
  "### 🚀 How to Use:\n\n"
76
- "1️⃣ Upload or take a photo of the affected skin area. \n\n"
77
  "2️⃣ Click the 'Submit' button.\n\n"
78
  "3️⃣ The app will return the probabilities for possible skin conditions.\n\n"
79
  "⚠️ **Important!** The results are for informational purposes only and do not constitute a medical diagnosis.\n\n"
@@ -81,15 +81,22 @@ description = (
81
  "- PyTorch (Lightning)\n"
82
  "- Gradio\n"
83
  "- Hugging Face Spaces\n\n"
84
- "🔗 Source Code: [GitHub/Hugging Face](https://huggingface.co/spaces/Eraly-ml/Skin-AI)"
85
  )
86
 
 
 
 
 
 
 
87
  interface = gr.Interface(
88
  fn=predict,
89
- inputs=gr.Image(type="pil", label="Image"),
90
  outputs=gr.Label(num_top_classes=3, label="Prediction"),
91
  title=title,
92
  description=description,
 
93
  theme=gr.themes.Soft()
94
  )
95
 
 
7
 
8
  def load_model() -> Tuple[torch.nn.Module, List[str]]:
9
  """
10
+ Loads the model and class labels.
11
 
12
  Returns:
13
+ model: The loaded PyTorch model.
14
+ labels: List of class labels.
15
  """
16
  model_path = "skin_disease_model_jit.pt"
17
  labels_path = "labels.txt"
18
 
19
  if not os.path.exists(model_path):
20
+ raise FileNotFoundError(f"Model not found: {model_path}")
21
  if not os.path.exists(labels_path):
22
+ raise FileNotFoundError("File labels.txt not found.")
23
 
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model = torch.jit.load(model_path, map_location=device)
26
  model.eval()
 
32
 
33
  model, labels = load_model()
34
 
35
+ # Define image preprocessing steps
36
  preprocess = transforms.Compose([
37
  transforms.Resize((224, 224)),
38
  transforms.ToTensor(),
 
42
 
43
  def predict(image: Image.Image) -> Dict[str, float]:
44
  """
45
+ Makes a prediction for the given image.
46
 
47
  Args:
48
+ image (PIL.Image): The input image.
49
 
50
  Returns:
51
+ Dict[str, float]: A dictionary where keys are class names, and values are probabilities.
52
  """
53
  try:
54
  image = image.convert("RGB")
55
  image_tensor = preprocess(image).unsqueeze(0)
 
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
  image_tensor = image_tensor.to(device)
58
  model.to(device)
59
+
60
  with torch.no_grad():
61
  output = model(image_tensor)
62
  scores = torch.nn.functional.softmax(output[0], dim=0)
63
+
64
  predictions = {label: float(score) for label, score in zip(labels, scores)}
 
65
  sorted_predictions = dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
66
+
67
  return sorted_predictions
68
  except Exception as e:
69
  return {"error": str(e)}
 
73
  "🔬 **Skin-AI — AI-Powered Skin Disease Classification**\n\n"
74
  "This project utilizes a deep learning model to classify skin diseases based on an uploaded image.\n\n"
75
  "### 🚀 How to Use:\n\n"
76
+ "1️⃣ Upload or take a photo of the affected skin area.\n\n"
77
  "2️⃣ Click the 'Submit' button.\n\n"
78
  "3️⃣ The app will return the probabilities for possible skin conditions.\n\n"
79
  "⚠️ **Important!** The results are for informational purposes only and do not constitute a medical diagnosis.\n\n"
 
81
  "- PyTorch (Lightning)\n"
82
  "- Gradio\n"
83
  "- Hugging Face Spaces\n\n"
84
+ "🔗 Source Code: [GitHub/Hugging Face](https://huggingface.co/spaces/Eraly-ml/Skin-A)"
85
  )
86
 
87
+ # Adding example images
88
+ examples = [
89
+ ["example1.jpg"],
90
+ ["example2.jpg"]
91
+ ]
92
+
93
  interface = gr.Interface(
94
  fn=predict,
95
+ inputs=gr.Image(type="pil", label="Upload Image"),
96
  outputs=gr.Label(num_top_classes=3, label="Prediction"),
97
  title=title,
98
  description=description,
99
+ examples=examples,
100
  theme=gr.themes.Soft()
101
  )
102