"""Gradio app that classifies fish images.

Each uploaded image is first sent to a Hugging Face Inference API model. If
that call fails for any reason (network error, non-200 status, bad JSON), a
locally run ImageNet MobileNetV2 is used as a fallback.
"""

import functools
import io
import logging
import os

import gradio as gr
import numpy as np
import requests
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face API settings
HF_API_URL = "https://api-inference.huggingface.co/models/jeemsterri/fish_classification"
# Read the token from the environment so it never has to be committed to source.
# When it is empty, the request is sent without an Authorization header.
HF_API_KEY = os.environ.get("HF_API_KEY", "")
# Seconds to wait for the remote API before giving up and using the fallback.
HF_API_TIMEOUT = 30

# Example images offered in the UI (paths relative to the working directory).
EXAMPLE_IMAGES = ["salmon.jpg", "tuna.jpg"]


def _encode_image(image: Image.Image) -> bytes:
    """Serialize *image* to PNG file bytes suitable for an HTTP upload.

    ``Image.tobytes()`` yields raw, header-less pixel data that a remote image
    classifier cannot decode, so the image is saved to an in-memory PNG instead.
    """
    buffer = io.BytesIO()
    # Converting to RGB guarantees a PNG-compatible mode (e.g. CMYK is not).
    image.convert("RGB").save(buffer, format="PNG")
    return buffer.getvalue()


def _query_hf_api(img_bytes: bytes):
    """Send encoded image bytes to the Hugging Face API.

    Returns:
        The parsed JSON response on HTTP 200, or ``None`` on any failure
        (network error, timeout, non-200 status, non-JSON body), so the caller
        can fall back to the local model.
    """
    headers = {"Authorization": f"Bearer {HF_API_KEY}"} if HF_API_KEY else {}
    try:
        response = requests.post(
            HF_API_URL, headers=headers, data=img_bytes, timeout=HF_API_TIMEOUT
        )
    except requests.RequestException as exc:
        logger.warning("API request failed (%s), using fallback...", exc)
        return None

    if response.status_code != 200:
        logger.warning("API failed (status %s), using fallback...", response.status_code)
        return None

    try:
        predictions = response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        logger.warning("API returned a non-JSON body, using fallback...")
        return None

    logger.info("API response: %s", predictions)
    return predictions


@functools.lru_cache(maxsize=1)
def _load_mobilenet():
    """Build the ImageNet MobileNetV2 fallback model once and reuse it."""
    import tensorflow as tf  # imported lazily: only needed when the API fails

    return tf.keras.applications.MobileNetV2(weights="imagenet")


def _classify_with_mobilenet(image: Image.Image, top_k: int = 1) -> list:
    """Classify *image* with ImageNet MobileNetV2.

    Args:
        image: PIL image in any mode; it is converted to RGB.
        top_k: Number of top ImageNet classes to return.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts, best first.
    """
    import tensorflow as tf

    model = _load_mobilenet()
    # The model takes a batch of 224x224 RGB images; preprocess_input applies
    # the scaling MobileNetV2 was trained with.
    rgb = image.convert("RGB").resize((224, 224))
    batch = np.expand_dims(np.asarray(rgb, dtype=np.float32), axis=0)
    batch = tf.keras.applications.mobilenet_v2.preprocess_input(batch)
    probabilities = model.predict(batch, verbose=0)
    # decode_predictions returns, per sample, (class_id, label, score) tuples.
    decoded = tf.keras.applications.mobilenet_v2.decode_predictions(probabilities, top=top_k)[0]
    return [{"label": label, "score": float(score)} for _, label, score in decoded]


def classify_fish(image: Image.Image) -> dict:
    """
    Classify a fish image using Hugging Face API or fallback to MobileNet.

    Args:
        image: PIL Image object (``None`` when nothing was uploaded).
    Returns:
        ``{"source": ..., "predictions": ...}`` on success, or
        ``{"error": message}`` on failure.
    """
    if image is None:
        return {"error": "No image provided."}

    try:
        predictions = _query_hf_api(_encode_image(image))
        if predictions is not None:
            return {"source": "Hugging Face", "predictions": predictions}

        # Fallback to MobileNet if API fails
        return {
            "source": "MobileNet (Fallback)",
            "predictions": _classify_with_mobilenet(image),
        }
    except Exception as e:  # UI boundary: report any failure to the user
        logger.exception("Classification error: %s", e)
        return {"error": str(e)}


# Gradio Interface
interface = gr.Interface(
    fn=classify_fish,
    inputs=gr.Image(type="pil", label="Upload Fish Image"),
    outputs=gr.JSON(label="Prediction Results"),
    title="🐟 Fish Classifier",
    description="Upload an image of a fish to see the predicted class probabilities.",
    # Only offer example images that actually exist on disk; missing example
    # files break the examples gallery.
    examples=[path for path in EXAMPLE_IMAGES if os.path.isfile(path)] or None,
    theme="soft",
)

if __name__ == "__main__":
    # 0.0.0.0 listens on all network interfaces (e.g. inside a container).
    interface.launch(server_name="0.0.0.0", server_port=7860)