# 🦈 vit-shark-model: Vision Transformer for Shark Species Classification
A fine-tuned Vision Transformer (ViT-B/16) model that classifies 14 shark species from images, built by transfer learning from the Hugging Face `google/vit-base-patch16-224-in21k` backbone.
## 📊 Model Summary
| Item | Details |
|---|---|
| Base Model | google/vit-base-patch16-224-in21k |
| Fine-tuned on | 14 Shark Classes |
| Train Samples | 1395 |
| Validation Samples | 295 |
| Test Samples | 313 |
| Framework | PyTorch + Hugging Face Transformers |
| Processor | ViTImageProcessor |
| Epochs Trained | 5 |
| Test Accuracy | 88.18% |
| Test Loss | 0.7288 |
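
The test figures above would typically come from a `Trainer.evaluate` call. A minimal sketch, assuming `model` is the fine-tuned classifier and `test_dataset` is a preprocessed dataset with `pixel_values` and `labels` columns (both names here are assumptions, not the card's actual training code):

```python
import numpy as np
import evaluate
from transformers import Trainer, TrainingArguments

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # eval_pred bundles the model's logits and the reference labels
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)

trainer = Trainer(
    model=model,  # the fine-tuned ViTForImageClassification (assumed)
    args=TrainingArguments(output_dir="./vit-shark-model", per_device_eval_batch_size=16),
    compute_metrics=compute_metrics,
)
print(trainer.evaluate(eval_dataset=test_dataset))  # reports eval_loss and eval_accuracy
```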
## 🐟 Classes (14 Total)
- `blacktip_Shark`
- `blue_shark`
- `bull_Shark`
- `great_White_Shark`
- `grey_reef_shark`
- `hammerhead_Shark`
- `mako_Shark`
- `sand_tiger_shark`
- `sevengill_shark`
- `silky_shark`
- `silvertip_shark`
- `tiger_Shark`
- `whale_Shark`
- `whitetip_Shark`
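
For reference, fine-tuning from the `google/vit-base-patch16-224-in21k` backbone with these labels would typically start by re-initializing the classification head for 14 classes. A minimal sketch; the label order shown mirrors the list above, but the order actually used during training may differ:

```python
from transformers import ViTForImageClassification

# The 14 labels listed above (index order is illustrative)
labels = [
    "blacktip_Shark", "blue_shark", "bull_Shark", "great_White_Shark",
    "grey_reef_shark", "hammerhead_Shark", "mako_Shark", "sand_tiger_shark",
    "sevengill_shark", "silky_shark", "silvertip_shark", "tiger_Shark",
    "whale_Shark", "whitetip_Shark",
]

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(labels),
    id2label={i: label for i, label in enumerate(labels)},
    label2id={label: i for i, label in enumerate(labels)},
)
```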
## 🧾 Dataset Summary
The dataset was preprocessed with `ViTImageProcessor`: images were resized and normalized, and class names were mapped to integer label ids.
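
A minimal sketch of that preprocessing step, assuming the images live in an `imagefolder`-style directory (the `data_dir` path and column names are assumptions):

```python
from datasets import load_dataset
from transformers import ViTImageProcessor

image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

def preprocess(batch):
    # Resize to 224x224 and normalize with the processor's mean/std
    inputs = image_processor(
        [img.convert("RGB") for img in batch["image"]], return_tensors="pt"
    )
    inputs["labels"] = batch["label"]
    return inputs

# "imagefolder" infers labels from sub-directory names
dataset = load_dataset("imagefolder", data_dir="path/to/shark_images")
dataset = dataset.with_transform(preprocess)  # applied lazily on access
```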
## 📈 Training Metrics
| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 1.8651 | 1.7818 | 69.83% |
| 2 | 1.1247 | 1.2082 | 85.08% |
| 3 | 0.7594 | 0.9216 | 86.78% |
| 4 | 0.4500 | 0.7947 | 87.12% |
| 5 | 0.3931 | 0.7611 | 88.81% ✅ |
## 📉 Training Curves
Below are plots of the training/validation loss and validation accuracy over the training epochs.
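
If the plot images do not render where this card is viewed, the curves can be reproduced directly from the metrics table above with a short matplotlib sketch:

```python
import matplotlib.pyplot as plt

# Values copied from the Training Metrics table above
epochs = [1, 2, 3, 4, 5]
train_loss = [1.8651, 1.1247, 0.7594, 0.4500, 0.3931]
val_loss = [1.7818, 1.2082, 0.9216, 0.7947, 0.7611]
val_acc = [69.83, 85.08, 86.78, 87.12, 88.81]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(epochs, train_loss, marker="o", label="Training loss")
ax1.plot(epochs, val_loss, marker="o", label="Validation loss")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.legend()
ax2.plot(epochs, val_acc, marker="o", color="green")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Validation accuracy (%)")
plt.tight_layout()
plt.show()
```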
## ✅ Use Cases
- Marine biology research
- Wildlife conservation
- Shark detection for underwater robotics
- Educational tools for biodiversity studies
## 🚫 Limitations
- The model may misclassify species with similar appearances.
- Performance could degrade on blurry, low-resolution, or occluded images.
- The model was fine-tuned on a relatively small dataset (~2,000 images total).
## How to Use

```python
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image as PILImage
import torch

# Path where the fine-tuned model was saved
model_path = "./vit-shark-model"

# Load the processor and the model
image_processor = ViTImageProcessor.from_pretrained(model_path)
model = ViTForImageClassification.from_pretrained(model_path)

# (Optional) Move the model to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # evaluation mode

# Example: preprocess an image and predict its species
image = PILImage.open("path/to/your/shark_image.jpg").convert("RGB")
inputs = image_processor(images=image, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

predicted_class_idx = logits.argmax(-1).item()
predicted_label = model.config.id2label[predicted_class_idx]
print(f"Predicted class: {predicted_label}")
```
The model can also be served through a simple Gradio demo (for example, in a Colab notebook):

```python
!pip install gradio transformers torch torchvision

import gradio as gr
from transformers import pipeline

# Load the model
try:
    classifier = pipeline("image-classification", model="HatemMoushir/DeepShark-ViT-Hatem-V1")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please ensure you have access to the model or that the model name is correct.")
    exit()  # exit if the model can't be loaded

# Define a function to make predictions
def predict_shark_species(image):
    if image is None:
        return "Please upload an image."
    # The pipeline handles preprocessing, but make sure the image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")
    # The output is a list of dictionaries, e.g. [{'score': 0.99, 'label': 'great_White_Shark'}]
    predictions = classifier(image)
    if not predictions:
        return "No prediction could be made."
    top_prediction = predictions[0]
    label = top_prediction["label"].replace("_", " ").title()  # format the label nicely
    score = top_prediction["score"] * 100  # convert to a percentage
    if score < 50:
        return "Prediction: Unknown (Low confidence)"
    return f"Prediction: **{label}**\nConfidence: **{score:.2f}%**"

# Create the Gradio interface
iface = gr.Interface(
    fn=predict_shark_species,
    inputs=gr.Image(type="pil", label="Upload Shark Image"),
    outputs="markdown",
    title="🦈 DeepShark-ViT-Hatem-V1: Shark Species Classifier",
    description="Upload an image of a shark to get a prediction of its species using the HatemMoushir/DeepShark-ViT-Hatem-V1 model.",
    examples=[],  # add local example image paths here if available
)

# Launch the interface
if __name__ == "__main__":
    print("Starting Gradio interface...")
    iface.launch(share=True)  # share=True generates a public link
```
## Acknowledgements

Special thanks to the scientific community and to the developers of the Hugging Face Transformers and PyTorch libraries for providing the tools needed to build and train this model. We are also grateful for the support and guidance provided by ChatGPT and Gemini at various stages of this project, and we thank the Colab platform for offering a suitable environment for training the model.
## License

This model is released under the CC BY-NC-SA 4.0 (`cc-by-nc-sa-4.0`) license.
## Author

Hatem Moushir (h_moushir@hotmail.com)