And00drew commited on
Commit
8c25b49
·
verified ·
1 Parent(s): 4d5e8b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -20
app.py CHANGED
@@ -2,43 +2,109 @@ from transformers import AutoImageProcessor, AutoModelForImageClassification
2
  import torch
3
  from PIL import Image
4
  import gradio as gr
 
 
5
 
6
- # Загрузка модели и процессора
7
- model = AutoModelForImageClassification.from_pretrained("jeemsterri/fish_classification")
8
- processor = AutoImageProcessor.from_pretrained("jeemsterri/fish_classification")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Функция классификации
11
  def classify_image(image):
12
  try:
13
- # Преобразование изображения и предсказание
14
- inputs = processor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
15
  with torch.no_grad():
16
  outputs = model(**inputs)
 
 
17
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
18
  confidence, predicted_class = torch.max(probs, dim=1)
19
- label = model.config.id2label[predicted_class.item()]
20
 
21
- # Формат ответа для API
22
- return {
23
- "label": label,
24
  "confidence": float(confidence),
25
  "top_classes": [
26
- {"label": model.config.id2label[i], "score": float(probs[0][i])}
27
- for i in torch.topk(probs, 3).indices[0]
 
 
 
28
  ]
29
  }
 
 
 
 
30
  except Exception as e:
31
- return {"error": str(e)}
 
 
 
 
 
32
 
33
- # Gradio Interface
34
  iface = gr.Interface(
35
  fn=classify_image,
36
- inputs=gr.Image(type="pil"),
37
- outputs=gr.JSON(),
38
- title="Fish Species Classifier",
39
- description="Upload a fish image to classify its species."
 
 
 
 
 
 
 
 
 
 
 
 
40
  )
41
 
42
- # Запуск с CORS (если нужно)
43
  if __name__ == "__main__":
44
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
2
  import torch
3
  from PIL import Image
4
  import gradio as gr
5
+ import logging
6
+ from functools import lru_cache
7
 
8
+ # Настройка логирования
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Проверка доступности GPU
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ logger.info(f"Using device: {device}")
15
+
16
+ # Кэширование загрузки модели для ускорения последующих запросов
17
+ @lru_cache(maxsize=1)
18
+ def load_model():
19
+ logger.info("Loading model and processor...")
20
+ try:
21
+ model = AutoModelForImageClassification.from_pretrained(
22
+ "jeemsterri/fish_classification"
23
+ ).to(device)
24
+ processor = AutoImageProcessor.from_pretrained("jeemsterri/fish_classification")
25
+ logger.info("Model loaded successfully")
26
+ return model, processor
27
+ except Exception as e:
28
+ logger.error(f"Error loading model: {str(e)}")
29
+ raise
30
+
31
+ # Загрузка модели при старте
32
+ try:
33
+ model, processor = load_model()
34
+ except Exception as e:
35
+ logger.error(f"Failed to load model: {str(e)}")
36
+ raise
37
 
 
38
  def classify_image(image):
39
  try:
40
+ # Проверка входного изображения
41
+ if not isinstance(image, Image.Image):
42
+ image = Image.fromarray(image)
43
+
44
+ logger.info("Processing image...")
45
+
46
+ # Преобразование изображения
47
+ inputs = processor(images=image, return_tensors="pt").to(device)
48
+
49
+ # Предсказание
50
  with torch.no_grad():
51
  outputs = model(**inputs)
52
+
53
+ # Обработка результатов
54
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
55
  confidence, predicted_class = torch.max(probs, dim=1)
56
+ top_classes = torch.topk(probs, 3)
57
 
58
+ # Формирование результата
59
+ result = {
60
+ "label": model.config.id2label[predicted_class.item()],
61
  "confidence": float(confidence),
62
  "top_classes": [
63
+ {
64
+ "label": model.config.id2label[i.item()],
65
+ "score": float(probs[0][i])
66
+ }
67
+ for i in top_classes.indices[0]
68
  ]
69
  }
70
+
71
+ logger.info(f"Prediction result: {result}")
72
+ return result
73
+
74
  except Exception as e:
75
+ error_msg = f"Classification error: {str(e)}"
76
+ logger.error(error_msg)
77
+ return {
78
+ "error": error_msg,
79
+ "available_labels": list(model.config.id2label.values())[:10] + ["..."]
80
+ }
81
 
82
+ # Создание интерфейса Gradio с улучшенным UI
83
  iface = gr.Interface(
84
  fn=classify_image,
85
+ inputs=gr.Image(
86
+ type="pil",
87
+ label="Upload Fish Image",
88
+ sources=["upload", "webcam", "clipboard"]
89
+ ),
90
+ outputs=gr.JSON(
91
+ label="Classification Results"
92
+ ),
93
+ title="🐟 Fish Species Classifier",
94
+ description="Upload an image of a fish to identify its species using AI",
95
+ examples=[
96
+ ["salmon.jpg"],
97
+ ["trout.jpg"]
98
+ ],
99
+ allow_flagging="never",
100
+ theme=gr.themes.Soft()
101
  )
102
 
103
+ # Конфигурация запуска
104
  if __name__ == "__main__":
105
+ iface.launch(
106
+ server_name="0.0.0.0",
107
+ server_port=7860,
108
+ enable_queue=True,
109
+ share=False
110
+ )