And00drew committed on
Commit
12d721d
·
verified ·
1 Parent(s): 58c0dde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -81
app.py CHANGED
@@ -1,100 +1,73 @@
1
- from transformers import AutoImageProcessor, AutoModelForImageClassification
2
- import torch
3
- from PIL import Image
4
  import gradio as gr
 
 
 
5
  import logging
6
- from functools import lru_cache
7
- import os
8
 
9
- # Настройка логирования
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # Проверка доступности GPU
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
15
- logger.info(f"Using device: {device}")
16
 
17
- # Кэширование загрузки модели
18
# Model loading is cached so the heavy from_pretrained call runs only once.
@lru_cache(maxsize=1)
def load_model():
    """Load the fish-classification model and its image processor.

    Returns:
        (model, processor): the model moved to the detected device and the
        image processor created with the fast implementation.

    Raises:
        Exception: any error from `from_pretrained` is logged and re-raised.
    """
    logger.info("Loading model and processor...")
    try:
        model = AutoModelForImageClassification.from_pretrained(
            "jeemsterri/fish_classification"
        ).to(device)
        # use_fast=True selects the fast image-processor implementation
        processor = AutoImageProcessor.from_pretrained(
            "jeemsterri/fish_classification", use_fast=True
        )
        logger.info("Model loaded successfully")
        return model, processor
    except Exception as exc:
        logger.error(f"Error loading model: {str(exc)}")
        raise

# Load at import time so the app fails fast if the model cannot be fetched.
try:
    model, processor = load_model()
except Exception as exc:
    logger.error(f"Failed to load model: {str(exc)}")
    raise
38
 
39
def classify_image(image):
    """Classify a fish image and report the best guess plus the top-3 scores.

    Args:
        image: a PIL image, or a numpy array (converted to PIL if needed).

    Returns:
        dict with "label", "confidence" and "top_classes" on success; on
        failure, a dict with "error" and a sample of the available labels.
    """
    try:
        # Gradio may hand us a raw numpy array instead of a PIL image.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        logger.info("Processing image...")

        # Preprocess, then run the forward pass without gradient tracking.
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # Softmax over the logits gives per-class probabilities.
        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        confidence, predicted_class = torch.max(probs, dim=1)
        top_classes = torch.topk(probs, 3)

        result = {
            "label": model.config.id2label[predicted_class.item()],
            "confidence": float(confidence),
            "top_classes": [
                {
                    "label": model.config.id2label[idx.item()],
                    "score": float(probs[0][idx]),
                }
                for idx in top_classes.indices[0]
            ],
        }

        logger.info(f"Prediction result: {result}")
        return result

    except Exception as e:
        error_msg = f"Classification error: {str(e)}"
        logger.error(error_msg)
        return {
            "error": error_msg,
            "available_labels": list(model.config.id2label.values())[:10] + ["..."],
        }
82
 
83
# Build the Gradio UI around the classifier.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Fish Image"),
    outputs=gr.JSON(label="Classification Results"),
    title="🐟 Fish Species Classifier",
    description="Upload an image of a fish to identify its species",
    examples=None,  # no example image files ship with the app
    flagging_mode="never",  # replaces the deprecated allow_flagging option
    cache_examples=False,  # nothing to cache with examples disabled
)

# Serve on all interfaces on the standard Hugging Face Spaces port.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
import io
import logging
import os

import gradio as gr
import numpy as np
import requests
from PIL import Image
 
 
6
 
7
+ # Configure logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
# Hugging Face Inference API settings.
# Read the token from the HF_API_KEY environment variable (e.g. a Space
# secret) so a real key never has to be committed to source control; the
# original placeholder remains the default for backward compatibility.
HF_API_URL = "https://api-inference.huggingface.co/models/jeemsterri/fish_classification"
HF_API_KEY = os.environ.get("HF_API_KEY", "your_huggingface_api_key")
14
 
15
def classify_fish(image: Image.Image) -> dict:
    """
    Classify a fish image via the Hugging Face Inference API, falling back
    to a generic MobileNet classifier when the API call fails.

    Args:
        image: PIL Image object (as supplied by the Gradio image input).

    Returns:
        Dict with "source" and "predictions" on success, or an "error" key
        describing the failure.
    """
    try:
        # Encode the image as a complete PNG file. The Inference API expects
        # an encoded image file; Image.tobytes() yields raw headerless pixel
        # data that the API cannot decode, so it would always fail.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        img_bytes = buffer.getvalue()

        # Try the Hugging Face API first.
        headers = {"Authorization": f"Bearer {HF_API_KEY}"}
        response = requests.post(HF_API_URL, headers=headers, data=img_bytes)

        if response.status_code == 200:
            predictions = response.json()
            logger.info(f"API response: {predictions}")
            return {"source": "Hugging Face", "predictions": predictions}

        # Fall back to MobileNet if the API fails. TensorFlow is imported
        # lazily so the app still starts when it is not installed and the
        # API path succeeds.
        logger.warning(f"API failed (status {response.status_code}), using fallback...")
        import tensorflow as tf
        import tensorflow_hub as hub

        # Load MobileNet from TF Hub.
        model = tf.keras.Sequential([
            hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4")
        ])
        # Force 3-channel RGB so RGBA or grayscale uploads do not break the
        # (1, 224, 224, 3) input shape MobileNet expects.
        image = image.convert("RGB").resize((224, 224))
        image_array = np.array(image) / 255.0
        image_array = np.expand_dims(image_array, axis=0)

        predictions = model.predict(image_array)
        top_prediction = tf.keras.applications.mobilenet_v2.decode_predictions(predictions, top=1)[0][0]

        return {
            "source": "MobileNet (Fallback)",
            "predictions": [{"label": top_prediction[1], "score": float(top_prediction[2])}]
        }

    except Exception as e:
        logger.error(f"Classification error: {str(e)}")
        return {"error": str(e)}
60
+
61
# Gradio interface. `examples` is left as None because no example image
# files ("salmon.jpg", "tuna.jpg") ship with the app -- pointing Gradio at
# missing files makes it error out at startup, which is exactly why the
# previous revision of this app removed its examples.
interface = gr.Interface(
    fn=classify_fish,
    inputs=gr.Image(type="pil", label="Upload Fish Image"),
    outputs=gr.JSON(label="Prediction Results"),
    title="🐟 Fish Classifier",
    description="Upload an image of a fish to see the predicted class probabilities.",
    examples=None,
    theme="soft"
)
71
 
 
72
# Entry point: bind to all interfaces on the standard HF Spaces port.
if __name__ == "__main__":
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )