Spaces:
Sleeping
Sleeping
| import json | |
| import gradio, torch | |
| from safetensors.torch import load_file | |
| from torch import nn, cuda | |
| from torchvision.models import resnet50, ResNet50_Weights | |
| from torchvision.transforms import v2 | |
def predict(image):
    """Classify an uploaded image as "ai" or "seniman" (artist) using the fine-tuned ResNet-50.

    Args:
        image: PIL image delivered by the ``gradio.Image(type="pil")`` input, or
            ``None`` when the user presses submit without uploading anything.

    Returns:
        A 3-tuple matching the three interface outputs:
        - a JSON string ``{"label": <predicted label>, "probability": {label: p, ...}}``
        - a text line with the percentage for the first entry of ``labels``
        - a text line with the percentage for the second entry of ``labels``

    Raises:
        gradio.Error: if no image was provided, so the UI shows a readable message
            instead of an ``AttributeError`` from calling ``.convert`` on ``None``.
    """
    if image is None:
        raise gradio.Error("Silakan unggah gambar terlebih dahulu.")

    # Force 3-channel RGB (RGBA / greyscale / palette uploads would otherwise
    # reach the pipeline with a different channel count), then add a batch axis.
    batch = transforms(image.convert("RGB")).unsqueeze(0).to(device)

    # Pure inference: no autograd graph is needed because results become Python floats.
    with torch.inference_mode():
        logits = model(batch)

    prediction = logits.argmax(dim=1).item()
    scores = nn.functional.softmax(logits, dim=1)[0].tolist()
    # labels is keyed by output index in insertion order, so zip pairs each
    # class name with its probability.
    probabilities = dict(zip(labels.values(), scores))
    label = labels[prediction]

    json_component = json.dumps({
        "label": label,
        "probability": probabilities,
    })

    # The interface has exactly two text outputs, one per class.
    class_a, class_b = labels.values()
    percentage_a = probabilities[class_a] * 100
    percentage_b = probabilities[class_b] * 100
    return (
        json_component,
        f"Kemungkinan dibuat oleh {class_a.upper()}: {percentage_a:.4f}%",
        f"Kemungkinan dibuat oleh {class_b.upper()}: {percentage_b:.4f}%",
    )
# Directory torch.hub uses for cached downloads (a relative path, so it is
# resolved against the current working directory).
torch.hub.set_dir("base_models")

# Classifier output index -> class name (0 = "ai", 1 = "seniman").
labels = dict(enumerate(("ai", "seniman")))

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
weight = ResNet50_Weights.IMAGENET1K_V2

# Only the normalisation statistics are borrowed from the backbone's ImageNet
# preset; resizing below uses its own fixed size. Build the preset once.
_imagenet_preset = weight.transforms()

transforms = v2.Compose(
    [
        # PIL image -> tensor, without value scaling.
        v2.PILToTensor(),
        # Integer pixel values -> float32 in [0, 1].
        v2.ToDtype(torch.float32, scale=True),
        # Resize (single int: the shorter side becomes 256).
        v2.Resize(256),
        # Per-channel standardisation with the ImageNet mean/std of the preset.
        v2.Normalize(mean=_imagenet_preset.mean, std=_imagenet_preset.std),
    ]
)

del _imagenet_preset
# Build the bare ResNet-50 architecture. Pretrained ImageNet weights are not
# requested: load_state_dict below runs with strict=True (the default), so every
# parameter and buffer must come from the fine-tuned checkpoint and will be
# overwritten anyway. Loading/downloading the ImageNet weights first would only
# add startup time (and a network fetch when they are not cached).
model = resnet50(weights=None)

# Replace the 1000-way ImageNet head with a 2-way head (one logit per entry of
# the two-class `labels` mapping).
model.fc = nn.Linear(model.fc.in_features, 2)

# Fine-tuned weights; a plain string path (no f-string placeholders needed).
state_dict = load_file("main_models/model.safetensors", device=device)
model.load_state_dict(state_dict)
model.to(device)
# Evaluation mode: BatchNorm layers use their running statistics at inference.
model.eval()
# One image upload in; a JSON summary plus one text pane per class out.
image_input = gradio.Image(label="Gambar", type="pil")
result_outputs = [
    gradio.JSON(label="Hasil Prediksi"),
    gradio.Label(label="Kemungkinan AI"),
    gradio.Label(label="Kemungkinan Seniman"),
]

demo = gradio.Interface(
    fn=predict,
    inputs=[image_input],
    outputs=result_outputs,
    flagging_mode="never",
    submit_btn="Prediksi",
)

# Enable the request queue with at most one concurrent run per event
# (presumably so only one inference uses the model at a time), then serve.
demo.queue(default_concurrency_limit=1)
demo.launch()