File size: 2,574 Bytes
366511c
12d721d
 
 
8c25b49
366511c
12d721d
8c25b49
 
 
12d721d
 
 
8c25b49
12d721d
 
 
 
 
 
 
 
8c25b49
12d721d
 
366511c
12d721d
 
 
8c25b49
12d721d
 
 
 
4d5e8b3
12d721d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c25b49
 
12d721d
 
8c25b49
366511c
12d721d
 
 
 
 
 
 
a33cf73
12d721d
 
 
 
 
4d5e8b3
 
 
12d721d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import io
import logging
import os

import gradio as gr
import numpy as np
import requests
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face API settings
HF_API_URL = "https://api-inference.huggingface.co/models/jeemsterri/fish_classification"
# Read the token from the HF_API_KEY environment variable so a real credential
# never has to be committed to the source file. The original placeholder is
# kept as the default, so behavior is unchanged when the variable is not set.
HF_API_KEY = os.environ.get("HF_API_KEY", "your_huggingface_api_key")

# Seconds to wait for the hosted model before giving up and using the fallback.
_API_TIMEOUT_SECONDS = 30

# Lazily created MobileNet fallback model. Loading it is expensive, so the
# first fallback call populates this and later calls reuse it.
_fallback_model = None


def _encode_image(image: Image.Image) -> bytes:
    """Serialize *image* as PNG file bytes suitable for an HTTP upload."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def _get_fallback_model():
    """Return the MobileNet fallback model, loading it on first use."""
    global _fallback_model
    if _fallback_model is None:
        import tensorflow as tf
        import tensorflow_hub as hub

        _fallback_model = tf.keras.Sequential([
            hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4")
        ])
    return _fallback_model


def _classify_with_mobilenet(image: Image.Image) -> dict:
    """Classify an RGB *image* locally with MobileNet and return the top label."""
    import tensorflow as tf

    model = _get_fallback_model()

    resized = image.resize((224, 224))  # MobileNet expects 224x224
    batch = np.expand_dims(np.asarray(resized, dtype=np.float32) / 255.0, axis=0)

    logits = np.asarray(model.predict(batch))

    # The hub model returns raw logits; apply a numerically stable softmax so
    # the reported score is a probability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exponentials = np.exp(shifted)
    probabilities = exponentials / exponentials.sum(axis=-1, keepdims=True)

    # The hub model emits 1001 classes (index 0 is "background"), while
    # decode_predictions only accepts the 1000 ImageNet classes.
    if probabilities.shape[-1] == 1001:
        probabilities = probabilities[:, 1:]

    top_prediction = tf.keras.applications.mobilenet_v2.decode_predictions(
        probabilities, top=1
    )[0][0]

    return {
        "source": "MobileNet (Fallback)",
        "predictions": [{"label": top_prediction[1], "score": float(top_prediction[2])}]
    }


def classify_fish(image: Image.Image) -> dict:
    """
    Classify a fish image using Hugging Face API or fallback to MobileNet.

    The image is uploaded to the hosted model as an encoded PNG file. If the
    request cannot be completed or returns a non-200 status, the image is
    classified locally with MobileNet instead.

    Args:
        image: PIL Image object (``None`` when nothing was uploaded).
    Returns:
        Dict with "source" and "predictions" on success, or a dict with a
        single "error" key describing what went wrong.
    """
    if image is None:
        return {"error": "No image provided"}

    try:
        # Normalize to 3-channel RGB so both the PNG upload and the MobileNet
        # input work for RGBA, palette, or grayscale uploads.
        rgb_image = image.convert("RGB")

        # Try Hugging Face API first. The API expects an encoded image file,
        # not the raw pixel buffer.
        headers = {"Authorization": f"Bearer {HF_API_KEY}"}
        try:
            response = requests.post(
                HF_API_URL,
                headers=headers,
                data=_encode_image(rgb_image),
                timeout=_API_TIMEOUT_SECONDS,
            )
        except requests.RequestException as exc:
            # Network failure or timeout: treat like any other API failure.
            logger.warning("API request failed (%s), using fallback...", exc)
        else:
            if response.status_code == 200:
                predictions = response.json()
                logger.info(f"API response: {predictions}")
                return {"source": "Hugging Face", "predictions": predictions}

            logger.warning(f"API failed (status {response.status_code}), using fallback...")

        # Fallback to MobileNet if API fails
        return _classify_with_mobilenet(rgb_image)

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing,
        # and keep the traceback in the log.
        logger.exception("Classification error: %s", e)
        return {"error": str(e)}

# Gradio Interface
# Only offer example images that are actually present on disk: a listed
# example file that is missing can make building the interface fail.
_example_images = [
    path for path in ("salmon.jpg", "tuna.jpg") if os.path.isfile(path)
]

interface = gr.Interface(
    fn=classify_fish,
    inputs=gr.Image(type="pil", label="Upload Fish Image"),
    outputs=gr.JSON(label="Prediction Results"),
    title="🐟 Fish Classifier",
    description="Upload an image of a fish to see the predicted class probabilities.",
    examples=_example_images or None,  # None means no examples are shown
    theme="soft"
)

if __name__ == "__main__":
    # 0.0.0.0 listens on all network interfaces, on port 7860.
    interface.launch(server_name="0.0.0.0", server_port=7860)