ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders
Paper • 2301.00808 • Published • 1
How to use computervisionpro/convnextv2-real-fake with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-classification", model="computervisionpro/convnextv2-real-fake")
predictions = pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png")
print(predictions)

# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("computervisionpro/convnextv2-real-fake")
model = AutoModelForImageClassification.from_pretrained("computervisionpro/convnextv2-real-fake")

# This is the model card of a 🤗 transformers model that has been pushed on the Hub.
# This model card has been automatically generated.
This model is fine-tuned from the ConvNeXt V2 base model (facebook/convnextv2-tiny-1k-224).
This fine-tuned model can be used for image classification: it has been trained to classify images as real or fake.
import functools
import os

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
MODEL_ID = "computervisionpro/convnextv2-real-fake"


@functools.lru_cache(maxsize=4)
def _load(model_id, device):
    """Load the image processor and model for *model_id*, placed on *device*.

    Results are cached per ``(model_id, device)`` pair, so repeated
    ``predict`` calls reuse the already-loaded weights instead of reading
    them again on every call. The model is switched to eval mode once here.
    """
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    model.to(device)
    model.eval()
    return processor, model


def predict(image_path, model_id=MODEL_ID, *, device="cpu"):
    """Classify a single image with the fine-tuned ConvNeXt V2 real/fake model.

    Args:
        image_path: Path to the image file. It is opened with PIL and
            converted to RGB before preprocessing.
        model_id: Hugging Face model id passed to ``from_pretrained``.
        device: Torch device that the model and inputs are moved to.
            Defaults to ``"cpu"``. Pass e.g. ``"cuda"`` to run on a GPU.

    Returns:
        A dict with these keys:

        - ``"image"``: the given path.
        - ``"model"``: the model id.
        - ``"prediction"``: the highest-probability label.
        - ``"confidence"``: that label's softmax probability.
        - ``"probabilities"``: a label -> probability mapping for every class.

        Labels come from ``model.config.id2label``. When a class index has no
        entry there, the stringified index is used instead.
    """
    processor, model = _load(model_id, device)

    # Use a context manager so the underlying file handle is released.
    # convert() returns a new, fully loaded image, so it remains usable
    # after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.inference_mode():
        outputs = model(**inputs)

    # A single image is processed, so take row 0 of the batch of class scores.
    probs = torch.softmax(outputs.logits, dim=-1)[0]
    pred_id = int(torch.argmax(probs).item())
    id2label = model.config.id2label

    return {
        "image": image_path,
        "model": model_id,
        "prediction": id2label.get(pred_id, str(pred_id)),
        "confidence": float(probs[pred_id].item()),
        "probabilities": {
            id2label.get(i, str(i)): float(prob.item())
            for i, prob in enumerate(probs)
        },
    }
# Run the demo only when executed as a script, not when this module is
# imported (e.g. to reuse `predict`).
if __name__ == "__main__":
    result = predict("./dataset/test/fake/fake_1006.jpg")
    print()
    print(result)
Base model
facebook/convnextv2-tiny-1k-224