# 🦈 vit-shark-model: Vision Transformer for Shark Species Classification

A fine-tuned Vision Transformer (ViT-B/16) trained to classify 14 shark species from images.
The model builds on transfer learning from the Hugging Face google/vit-base-patch16-224-in21k backbone.


## 📊 Model Summary

| Item | Details |
|---|---|
| Base Model | google/vit-base-patch16-224-in21k |
| Fine-tuned on | 14 shark classes |
| Train Samples | 1395 |
| Validation Samples | 295 |
| Test Samples | 313 |
| Framework | PyTorch + Hugging Face Transformers |
| Processor | ViTImageProcessor |
| Epochs Trained | 5 |
| Test Accuracy | 88.18% |
| Test Loss | 0.7288 |

๐Ÿ‹ Classes (14 Total)

- `blacktip_Shark`
- `blue_shark`
- `bull_Shark`
- `great_White_Shark`
- `grey_reef_shark`
- `hammerhead_Shark`
- `mako_Shark`
- `sand_tiger_shark`
- `sevengill_shark`
- `silky_shark`
- `silvertip_shark`
- `tiger_Shark`
- `whale_Shark`
- `whitetip_Shark`
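
These class names double as the model's output labels. A minimal sketch of the `label2id`/`id2label` mappings a fine-tuned ViT config carries; the alphabetical ordering shown is an assumption about how the dataset was indexed:

```python
# Build the label <-> id mappings of the kind stored in the model config.
# The 14 class names come from the list above; their alphabetical order
# is an assumption about how the dataset indexed them.
CLASS_NAMES = [
    "blacktip_Shark", "blue_shark", "bull_Shark", "great_White_Shark",
    "grey_reef_shark", "hammerhead_Shark", "mako_Shark", "sand_tiger_shark",
    "sevengill_shark", "silky_shark", "silvertip_shark", "tiger_Shark",
    "whale_Shark", "whitetip_Shark",
]

label2id = {name: i for i, name in enumerate(CLASS_NAMES)}
id2label = {i: name for i, name in enumerate(CLASS_NAMES)}

print(len(label2id))  # 14
print(id2label[0])    # blacktip_Shark
```

These dictionaries are what `model.config.id2label` resolves against at prediction time.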

## 🧾 Dataset Summary

The dataset was preprocessed with ViTImageProcessor: images were resized and normalized, and class names were mapped to integer labels.
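
The resize-and-normalize step can be replicated without the processor itself. The sketch below assumes the defaults of the google/vit-base-patch16-224-in21k checkpoint (224×224 input, per-channel mean and std of 0.5); `preprocess` and the dummy image are illustrative names, not part of the training code:

```python
import numpy as np
from PIL import Image

# Replicates what ViTImageProcessor does for this checkpoint, assuming
# the google/vit-base-patch16-224-in21k defaults: resize to 224x224,
# scale to [0, 1], then normalize each channel with mean 0.5 and std 0.5.
SIZE, MEAN, STD = 224, 0.5, 0.5

def preprocess(image: Image.Image) -> np.ndarray:
    image = image.convert("RGB").resize((SIZE, SIZE))
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # [0, 1]
    pixels = (pixels - MEAN) / STD                        # [-1, 1]
    return pixels.transpose(2, 0, 1)[None]                # (1, 3, 224, 224)

# A solid-color dummy image stands in for a real shark photo.
dummy = Image.new("RGB", (640, 480), color=(30, 90, 150))
batch = preprocess(dummy)
print(batch.shape)  # (1, 3, 224, 224)
```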


## 📈 Training Metrics

| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 1.8651 | 1.7818 | 69.83% |
| 2 | 1.1247 | 1.2082 | 85.08% |
| 3 | 0.7594 | 0.9216 | 86.78% |
| 4 | 0.4500 | 0.7947 | 87.12% |
| 5 | 0.3931 | 0.7611 | 88.81% ✅ |
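
The accuracy column above is the fraction of argmax predictions matching the true labels. A minimal sketch of the kind of `compute_metrics` hook passed to the Hugging Face Trainer for this (the exact hook used in training is not shown in this card):

```python
import numpy as np

# Accuracy hook of the kind passed to transformers.Trainer: take the
# argmax over class logits and compare with the ground-truth labels.
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": float((preds == labels).mean())}

# Tiny worked example: 3 of 4 predictions are correct.
logits = np.array([[2.0, 0.1], [0.2, 1.5], [3.0, 0.5], [0.1, 0.9]])
labels = np.array([0, 1, 1, 1])
print(compute_metrics((logits, labels)))  # {'accuracy': 0.75}
```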

## 📊 Training Curves

Below are plots of the loss and accuracy over the training epochs.

*Training and Evaluation Metrics*

## ✅ Use Cases

- Marine biology research
- Wildlife conservation
- Shark detection for underwater robotics
- Educational tools for biodiversity studies


## 🚫 Limitations

- The model may misclassify species with similar appearances.
- Performance may degrade on blurry, low-resolution, or occluded images.
- The model was fine-tuned on a relatively small dataset (~2000 images total).


## 🚀 How to Use

```python
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image as PILImage
import torch

# Path where the fine-tuned model was saved
model_path = "./vit-shark-model"

# Load the processor and the model
image_processor = ViTImageProcessor.from_pretrained(model_path)
model = ViTForImageClassification.from_pretrained(model_path)

# (Optional) Move the model to the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # evaluation mode

# Example: preprocess an image and run a prediction
image = PILImage.open("path/to/your/shark_image.jpg").convert("RGB")
inputs = image_processor(images=image, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
predicted_label = model.config.id2label[predicted_class_idx]

print(f"Predicted class: {predicted_label}")
```
### Gradio Demo

Install the dependencies (for example, in a Colab notebook):

```bash
pip install gradio transformers torch torchvision
```

```python
import gradio as gr
from transformers import pipeline

# Load the model
try:
    classifier = pipeline("image-classification", model="HatemMoushir/DeepShark-ViT-Hatem-V1")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure you have access to the model or that the model name is correct.")
    raise SystemExit(1)  # Exit if the model can't be loaded

# Define a function to make predictions
def predict_shark_species(image):
    if image is None:
        return "Please upload an image."

    # The pipeline handles preprocessing, but ensure the image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")

    # The output is a list of dictionaries,
    # e.g., [{'score': 0.99, 'label': 'great_White_Shark'}]
    predictions = classifier(image)

    if not predictions:
        return "No prediction could be made."

    # Format the top prediction
    top_prediction = predictions[0]
    label = top_prediction["label"].replace("_", " ").title()  # Format label nicely
    score = top_prediction["score"] * 100  # Convert to a percentage

    if score < 50:
        return "Prediction: Unknown (Low confidence)"
    return f"Prediction: **{label}**\nConfidence: **{score:.2f}%**"

# Create the Gradio interface
iface = gr.Interface(
    fn=predict_shark_species,
    inputs=gr.Image(type="pil", label="Upload Shark Image"),
    outputs="markdown",
    title="🦈 DeepShark-ViT-Hatem-V1: Shark Species Classifier",
    description="Upload an image of a shark to get a prediction of its species using the HatemMoushir/DeepShark-ViT-Hatem-V1 model.",
    examples=[
        # Add example image paths here if you have them locally,
        # e.g., ["path/to/your/shark_example1.jpg"]
    ],
)

# Launch the interface
if __name__ == "__main__":
    print("Starting Gradio interface...")
    iface.launch(share=True)  # share=True gives a public link, useful for sharing
```

## 🙏 Acknowledgements

Special thanks to the scientific community and the developers of the Hugging Face Transformers and PyTorch libraries for providing the essential tools needed to build and train this model. We also extend our sincere gratitude for the support and guidance provided by ChatGPT and Gemini throughout various stages of this project, which significantly contributed to its success. Finally, we thank the Colab platform for offering a suitable environment for training the model.


## License

cc-by-nc-sa-4.0

## Author

Hatem Moushir (h_moushir@hotmail.com)

Model size: 85.8M parameters (Safetensors, F32)