INVERTO committed on
Commit
fbb6522
·
verified ·
1 Parent(s): bd734a0

Upload trained bird captioning model, tokenizer, image processor, species mapping, and captions

Browse files
Files changed (1) hide show
  1. README.md +6 -3
README.md CHANGED
@@ -46,8 +46,11 @@ from huggingface_hub import PyTorchModelHubMixin
46
  import torch
47
  from model import BirdCaptioningModel # Save model.py locally
48
 
 
 
 
49
  # Load model
50
- model = BirdCaptioningModel.from_pretrained("INVERTO/bird-captioning-cub200")
51
  image_processor = ViTImageProcessor.from_pretrained("INVERTO/bird-captioning-cub200")
52
  tokenizer = AutoTokenizer.from_pretrained("INVERTO/bird-captioning-cub200")
53
  model.eval()
@@ -66,13 +69,13 @@ from PIL import Image
66
 
67
  def predict_bird_image(image_path):
68
  image = Image.open(image_path).convert("RGB")
69
- pixel_values = image_processor(image, return_tensors="pt").pixel_values
70
  with torch.no_grad():
71
  output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
72
  _, class_logits = model(pixel_values)
73
  predicted_class_idx = torch.argmax(class_logits, dim=1).item()
74
  confidence = torch.nn.functional.softmax(class_logits, dim=1)[0, predicted_class_idx].item() * 100
75
- caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
76
  species = species_mapping.get(predicted_class_idx, "Unknown")
77
  return caption, species, confidence
78
 
 
46
  import torch
47
  from model import BirdCaptioningModel # Save model.py locally
48
 
49
+ # Set device
50
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+
52
  # Load model
53
+ model = BirdCaptioningModel.from_pretrained("INVERTO/bird-captioning-cub200").to(device)
54
  image_processor = ViTImageProcessor.from_pretrained("INVERTO/bird-captioning-cub200")
55
  tokenizer = AutoTokenizer.from_pretrained("INVERTO/bird-captioning-cub200")
56
  model.eval()
 
69
 
70
  def predict_bird_image(image_path):
71
  image = Image.open(image_path).convert("RGB")
72
+ pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
73
  with torch.no_grad():
74
  output_ids = model.base_model.generate(pixel_values, max_length=75, num_beams=4)
75
  _, class_logits = model(pixel_values)
76
  predicted_class_idx = torch.argmax(class_logits, dim=1).item()
77
  confidence = torch.nn.functional.softmax(class_logits, dim=1)[0, predicted_class_idx].item() * 100
78
+ caption = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
79
  species = species_mapping.get(predicted_class_idx, "Unknown")
80
  return caption, species, confidence
81