# CSATv2 / example_2.py
# Commit e4d78d5 ("change_folder") by sosigikiller, 613 bytes.
import io

import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Classify one remote image with the CSATv2 model and print the top-k
# predictions with their labels and probabilities.
#
# NOTE: ``trust_remote_code=True`` runs Python code shipped inside the model
# repository; only enable it for repositories you trust.
MODEL_ID = "Hyunil/CSATv2"
TOP_K = 5

processor = AutoImageProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID, trust_remote_code=True)

url = "https://images.unsplash.com/photo-1516116216624-53e697fedbea"
# Use a timeout so the download cannot hang forever, and fail loudly on HTTP
# errors instead of handing an error page to PIL. Decoding from ``content``
# (not the raw socket stream) also honours any Content-Encoding the server
# applies. The context manager releases the connection when done.
with requests.get(url, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content)).convert("RGB")

inputs = processor(image, return_tensors="pt")
# Inference only: skip building the autograd graph to save memory and time.
with torch.no_grad():
    outputs = model(**inputs)

probs = outputs.logits.softmax(dim=-1)
# Clamp k so models with fewer than TOP_K classes do not make topk raise.
top_prob, top_idx = probs.topk(min(TOP_K, probs.shape[-1]))

# Map class indices to readable names when the config provides them; fall
# back to the bare index otherwise. flatten() covers a single-image batch
# (presumably logits are (1, num_labels) here — confirm for other inputs).
id2label = getattr(model.config, "id2label", None) or {}
for idx, prob in zip(top_idx.flatten().tolist(), top_prob.flatten().tolist()):
    print(f"{idx:>5}  {id2label.get(idx, str(idx))}  {prob:.4f}")