And00drew commited on
Commit
4d5e8b3
·
verified ·
1 Parent(s): c9189c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -13
app.py CHANGED
@@ -3,21 +3,42 @@ import torch
3
  from PIL import Image
4
  import gradio as gr
5
 
6
- # Загрузка модели
7
  model = AutoModelForImageClassification.from_pretrained("jeemsterri/fish_classification")
8
  processor = AutoImageProcessor.from_pretrained("jeemsterri/fish_classification")
9
 
10
- # Предсказание
11
  def classify_image(image):
12
- inputs = processor(images=image, return_tensors="pt")
13
- outputs = model(**inputs)
14
- probs = torch.nn.functional.softmax(outputs.logits, dim=1)
15
- confidence, predicted_class = torch.max(probs, dim=1)
16
- label = model.config.id2label[predicted_class.item()]
17
- return {label: float(confidence)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- gr.Interface(fn=classify_image,
20
- inputs=gr.Image(type="pil"),
21
- outputs=gr.Label(num_top_classes=3),
22
- title="Fish Species Classifier",
23
- description="Upload a fish image and get the predicted species.").launch()
 
 
 
 
 
 
 
 
3
  from PIL import Image
4
  import gradio as gr
5
 
6
+ # Загрузка модели и процессора
7
  model = AutoModelForImageClassification.from_pretrained("jeemsterri/fish_classification")
8
  processor = AutoImageProcessor.from_pretrained("jeemsterri/fish_classification")
9
 
10
+ # Функция классификации
11
  def classify_image(image):
12
+ try:
13
+ # Преобразование изображения и предсказание
14
+ inputs = processor(images=image, return_tensors="pt")
15
+ with torch.no_grad():
16
+ outputs = model(**inputs)
17
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
18
+ confidence, predicted_class = torch.max(probs, dim=1)
19
+ label = model.config.id2label[predicted_class.item()]
20
+
21
+ # Формат ответа для API
22
+ return {
23
+ "label": label,
24
+ "confidence": float(confidence),
25
+ "top_classes": [
26
+ {"label": model.config.id2label[i], "score": float(probs[0][i])}
27
+ for i in torch.topk(probs, 3).indices[0]
28
+ ]
29
+ }
30
+ except Exception as e:
31
+ return {"error": str(e)}
32
 
33
+ # Gradio Interface
34
+ iface = gr.Interface(
35
+ fn=classify_image,
36
+ inputs=gr.Image(type="pil"),
37
+ outputs=gr.JSON(),
38
+ title="Fish Species Classifier",
39
+ description="Upload a fish image to classify its species."
40
+ )
41
+
42
+ # Запуск с CORS (если нужно)
43
+ if __name__ == "__main__":
44
+ iface.launch(server_name="0.0.0.0", server_port=7860)