| import gradio as gr |
| import timm |
| import torch |
| from cods.classif.cp import ClassificationConformalizer |
| from cods.classif.data import ClassificationDataset |
| from cods.classif.data.predictions import ClassificationPredictions |
| from cods.classif.models import ClassificationModel |
| from datasets import load_dataset |
|
|
| |
| from PIL import Image |
|
|
| |
| from dataset import DatasetWrapper |
|
|
# Dataset choices offered in the UI, mapped to the identifier passed to
# ``load_dataset``. Insertion order matters: the first key is used as the
# default selection of the dataset dropdown.
DATASETS = dict(
    miniimagenet="timm/mini-imagenet",
    imagenette="frgfm/imagenette",
    imagenet="imagenet-1k",
)
|
|
# Model identifiers shown in the model dropdown, grouped by dataset key.
# Only "miniimagenet" has entries at the moment.
MODELS = dict(
    miniimagenet=[
        "QuentinJG/ResNet18-miniimagenet",
        "shahrukhx01/vit-base-patch16-miniimagenet",
    ],
)
|
|
# Module-level conformalizer shared by the two UI callbacks: ``calibrate``
# calls its ``calibrate`` method and ``predict_image`` calls ``conformalize``.
classification_conformalizer = ClassificationConformalizer(
    method="lac",
    preprocess="softmax",
)
|
|
|
|
def calibrate(dataset_name, model_name):
    """Calibrate the shared conformalizer on the selected dataset.

    Creates a pretrained timm model, wraps the "validation" split of the
    chosen dataset in ``DatasetWrapper``, builds predictions on it and
    calibrates ``classification_conformalizer`` with ``alpha=0.1``.

    The model and the wrapped dataset are published as the module globals
    ``pretrained_resnet_34`` and ``dataset`` (read by ``predict_image``) only
    after calibration has succeeded, so a failed run never leaves a
    half-initialised state behind.

    Args:
        dataset_name: Key into ``DATASETS``.
        model_name: Model selected in the UI. NOTE: currently ignored; the
            timm architecture is fixed to "resnet34" and the returned status
            reports the model that was actually used.

    Returns:
        A status string naming the dataset and the model used.

    Raises:
        KeyError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    global pretrained_resnet_34, dataset

    # Resolve the dataset first so that an unknown name fails before the
    # (potentially slow) model download.
    dataset_id = DATASETS[dataset_name]

    # TODO(review): honour ``model_name`` once the identifiers listed in
    # MODELS are confirmed to be loadable through ``timm.create_model``.
    timm_model_name = "resnet34"

    network = timm.create_model(timm_model_name, pretrained=True)
    # timm returns the model in training mode; switch to inference mode so
    # that layers such as batch-norm/dropout behave deterministically.
    network.eval()
    classifier = ClassificationModel(model=network, model_name=timm_model_name)

    calibration_set = DatasetWrapper(load_dataset(dataset_id, split="validation"))

    val_preds = classifier.build_predictions(
        calibration_set,
        dataset_name=dataset_name,
        split_name="cal",
        batch_size=512,
        shuffle=False,
    )
    classification_conformalizer.calibrate(val_preds, alpha=0.1)

    # Publish the state used by ``predict_image`` only once everything worked.
    pretrained_resnet_34 = network
    dataset = calibration_set
    return f"Calibrated on {dataset_name} with model {timm_model_name}"
|
|
|
|
def predict_image(img):
    """Return the uploaded image and its conformal prediction set as text.

    The image is preprocessed with the calibration dataset's ``transforms``,
    passed through the model stored by ``calibrate`` and the resulting
    prediction is conformalized with ``classification_conformalizer``.

    Args:
        img: The uploaded image (the UI passes a PIL image), or ``None`` when
            nothing has been uploaded.

    Returns:
        A ``(image, message)`` tuple. ``image`` is a copy of the input taken
        before preprocessing; ``message`` lists the predicted classes. When
        no image was supplied, or ``calibrate`` has not completed yet, the
        input is returned as-is together with an explanatory message
        instead of raising.
    """
    if img is None:
        return None, "Please upload an image first."

    # ``calibrate`` creates these globals; they do not exist before its
    # first successful run, so look them up defensively.
    model = globals().get("pretrained_resnet_34")
    calibration_set = globals().get("dataset")
    if model is None or calibration_set is None:
        return img, "Please run calibration before predicting."

    original = img.copy()
    batch = calibration_set.transforms(img).unsqueeze(0)

    # Inference only: make sure the model is in eval mode and skip autograd
    # bookkeeping.
    model.eval()
    with torch.no_grad():
        logits = model(batch)

    inference_pred = ClassificationPredictions(
        dataset_name="uploaded",
        split_name="test",
        image_paths=[None],
        idx_to_cls=calibration_set.idx_to_cls,
        true_cls=torch.tensor([-1]),  # placeholder: no ground-truth label
        pred_cls=logits,
    )

    conformal_sets = classification_conformalizer.conformalize(inference_pred)
    class_indices = conformal_sets[0].detach().cpu().tolist()
    list_of_classes = [calibration_set.idx_to_cls[i] for i in class_indices]
    return original, f"Predicted classes with 90% confidence: {list_of_classes}"
|
|
|
|
| |
| |
|
|
|
|
def main_function(lbd, img):
    """Return ``img`` unchanged; ``lbd`` is accepted but not used."""
    return img
|
|
|
|
def _model_choices(dataset_name):
    """Return the model identifiers registered for ``dataset_name`` (may be empty)."""
    return MODELS.get(dataset_name, [])


def _refresh_model_dropdown(dataset_name):
    """Re-populate the model dropdown after the dataset selection changes."""
    choices = _model_choices(dataset_name)
    return gr.update(choices=choices, value=choices[0] if choices else None)


# Build the UI: a calibration row (dataset/model selection + status) followed
# by an inference row (image upload + conformal prediction result).
with gr.Blocks() as demo:
    gr.Markdown("# Image Classification with Conformal Prediction")
    gr.Markdown("## Upload an image and get conformalized classification predictions.")

    with gr.Row():
        dataset_choices = list(DATASETS)
        default_dataset = dataset_choices[0]
        default_models = _model_choices(default_dataset)

        dataset_dropdown = gr.Dropdown(
            choices=dataset_choices,
            label="Select Dataset",
            value=default_dataset,
        )
        model_dropdown = gr.Dropdown(
            choices=default_models,
            label="Select Model",
            # Not every dataset has models registered; avoid indexing an
            # empty list.
            value=default_models[0] if default_models else None,
        )

    calibrate_btn = gr.Button("Calibrate")
    status_text = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("---")

    with gr.Row():
        input_image = gr.Image(label="Upload Image", type="pil")
        output_image = gr.Image(label="Processed Image")

    predict_btn = gr.Button("Predict")
    result_text = gr.Textbox(label="Prediction Result")

    # Keep the model list in sync with the selected dataset.
    dataset_dropdown.change(
        fn=_refresh_model_dropdown,
        inputs=dataset_dropdown,
        outputs=model_dropdown,
    )

    calibrate_btn.click(
        fn=calibrate,
        inputs=[dataset_dropdown, model_dropdown],
        outputs=status_text,
    )

    predict_btn.click(
        fn=predict_image,
        inputs=input_image,
        outputs=[output_image, result_text],
    )


# Only start the server when executed as a script, so importing this module
# (e.g. to reuse ``demo``) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|