Upload trained bird captioning model, tokenizer, image processor, species mapping, and captions
Browse files
README.md
CHANGED
|
@@ -46,8 +46,11 @@ from huggingface_hub import PyTorchModelHubMixin
|
|
| 46 |
import torch
|
| 47 |
from model import BirdCaptioningModel # Save model.py locally
|
| 48 |
|
|
|
|
|
|
|
|
|
|
| 49 |
# Load model
|
| 50 |
-
model = BirdCaptioningModel.from_pretrained("INVERTO/bird-captioning-cub200")
|
| 51 |
image_processor = ViTImageProcessor.from_pretrained("INVERTO/bird-captioning-cub200")
|
| 52 |
tokenizer = AutoTokenizer.from_pretrained("INVERTO/bird-captioning-cub200")
|
| 53 |
model.eval()
|
|
@@ -66,13 +69,13 @@ from PIL import Image
|
|
| 66 |
|
| 67 |
def predict_bird_image(image_path):
|
| 68 |
image = Image.open(image_path).convert("RGB")
|
| 69 |
-
pixel_values = image_processor(image, return_tensors="pt").pixel_values
|
| 70 |
with torch.no_grad():
|
| 71 |
output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
|
| 72 |
_, class_logits = model(pixel_values)
|
| 73 |
predicted_class_idx = torch.argmax(class_logits, dim=1).item()
|
| 74 |
confidence = torch.nn.functional.softmax(class_logits, dim=1)[0, predicted_class_idx].item() * 100
|
| 75 |
-
caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 76 |
species = species_mapping.get(predicted_class_idx, "Unknown")
|
| 77 |
return caption, species, confidence
|
| 78 |
|
|
|
|
import torch
from model import BirdCaptioningModel  # Save model.py locally

# Pick the compute device: prefer a CUDA GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned captioning model plus its preprocessing components
# from the Hub, move the model to the chosen device, and switch it to
# inference mode.
repo_id = "INVERTO/bird-captioning-cub200"
model = BirdCaptioningModel.from_pretrained(repo_id).to(device)
image_processor = ViTImageProcessor.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model.eval()
|
|
|
def predict_bird_image(image_path):
    """Generate a caption and predict the species for a bird image.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        Tuple ``(caption, species, confidence)`` where
        ``caption`` is the generated description (str),
        ``species`` is the predicted species name, or ``"Unknown"`` when
        the class index is absent from ``species_mapping`` (str), and
        ``confidence`` is the softmax probability of the predicted class
        expressed as a percentage in [0, 100] (float).
    """
    image = Image.open(image_path).convert("RGB")
    pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)

    # Pure inference: run both caption generation and the classification
    # forward pass without tracking gradients.
    with torch.no_grad():
        output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
        _, class_logits = model(pixel_values)

    # Normalize the logits once; derive both the predicted index and its
    # confidence from the same probability tensor (softmax is monotonic,
    # so the argmax is identical to argmax over the raw logits).
    probs = torch.nn.functional.softmax(class_logits, dim=1)
    predicted_class_idx = torch.argmax(probs, dim=1).item()
    confidence = probs[0, predicted_class_idx].item() * 100

    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    species = species_mapping.get(predicted_class_idx, "Unknown")
    return caption, species, confidence