| import argparse |
| import os |
| import warnings |
| from transformers import SwinForImageClassification, AutoImageProcessor |
| from PIL import Image |
| import torch |
|
|
| warnings.filterwarnings("ignore") |
| os.environ["TRANSFORMERS_VERBOSITY"] = "error" |
|
|
| def classify_image(image_path): |
| model = SwinForImageClassification.from_pretrained("./") |
| processor = AutoImageProcessor.from_pretrained("./", use_fast=False) |
|
|
| image = Image.open(image_path).convert("RGB") |
| inputs = processor(images=image, return_tensors="pt") |
|
|
| with torch.no_grad(): |
| outputs = model(**inputs) |
| logits = outputs.logits |
| predicted_class = torch.argmax(logits, dim=1).item() |
|
|
| label = model.config.id2label[predicted_class] |
| return "Real" if label == "human" else "Fake" |
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="Detecta se uma imagem é real ou gerada por IA.") |
| parser.add_argument("--image", type=str, required=True, help="Caminho para a imagem a ser analisada.") |
| |
| args = parser.parse_args() |
| result = classify_image(args.image) |
| print(result) |