Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| from peft import PeftModel | |
| from PIL import Image | |
| import torch | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from torchvision.transforms import ( | |
| CenterCrop, | |
| Compose, | |
| Normalize, | |
| RandomHorizontalFlip, | |
| RandomResizedCrop, | |
| Resize, | |
| ToTensor, | |
| ) | |
# Base ViT checkpoint and the LoRA adapter that was fine-tuned on top of it.
model_name = 'google/vit-large-patch16-224'
adapter = 'monsoon-nlp/eyegazer-vit-binary'

# The base checkpoint's processor supplies the input resolution
# (size["height"]) and the per-channel mean/std used for normalisation.
image_processor = AutoImageProcessor.from_pretrained(model_name)
normalize = Normalize(
    mean=image_processor.image_mean,
    std=image_processor.image_std,
)

# Augmenting pipeline: random resized crop + random horizontal flip.
# NOTE(review): query() below does not use this; presumably kept to mirror
# the training recipe — confirm before removing.
train_transforms = Compose([
    RandomResizedCrop(image_processor.size["height"]),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Deterministic inference pipeline: resize the shorter side to the model's
# input height, center-crop to a square of that size, convert to a tensor,
# then normalise with the processor's statistics.
val_transforms = Compose([
    Resize(image_processor.size["height"]),
    CenterCrop(image_processor.size["height"]),
    ToTensor(),
    normalize,
])
# Load the base ViT classifier with a 2-way output head. The base
# checkpoint's classifier has a different shape, so ignore_mismatched_sizes
# lets from_pretrained replace it with a newly initialised head instead of
# raising.
# NOTE(review): that new head starts random; predictions are only meaningful
# if the adapter checkpoint restores classifier weights (e.g. saved via
# modules_to_save) — confirm against the adapter's config.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True,
)

# Wrap the base model with the LoRA adapter weights fetched from `adapter`.
lora_model = PeftModel.from_pretrained(model, adapter)
def query(img):
    """Classify a fundus camera image as unaffected or affected.

    Args:
        img: PIL image delivered by the Gradio input component, or ``None``
            when the user submits without providing an image.

    Returns:
        A short Markdown string describing the predicted label, or a prompt
        to upload an image when ``img`` is ``None``.
    """
    if img is None:
        # Gradio passes None when nothing was uploaded; previously this
        # crashed with AttributeError on ``img.convert``.
        return "Please upload a fundus image."

    # ToTensor yields (channels, H, W); unsqueeze adds the batch dimension.
    batch = val_transforms(img.convert("RGB")).unsqueeze(0)

    # Pure inference: disable autograd so no computation graph is built,
    # saving memory and time on every request.
    with torch.no_grad():
        vals = lora_model(batch).logits[0].tolist()

    # Logit index 0 maps to "unaffected", index 1 to "affected"; a tie (or
    # NaN) falls through to "affected", as in the original comparison.
    if vals[0] > vals[1]:
        return "Predicted unaffected"
    return "Predicted affected to some degree"
# Gradio UI: one PIL image input, a Markdown output holding the prediction,
# and a few bundled sample images (stored in ./images next to this script)
# offered as clickable examples.
iface = gr.Interface(
    fn=query,
    examples=[
        os.path.join(os.path.dirname(__file__), "images", filename)
        for filename in (
            "0a09aa7356c0.png",
            "0a4e1a29ffff.png",
            "0c43c79e8cfb.png",
            "0c7e82daf5a0.png",
        )
    ],
    inputs=[
        gr.Image(
            image_mode='RGB',
            sources=['upload', 'clipboard'],
            type='pil',
            label='Input Fundus Camera Image',
            show_label=True,
        ),
    ],
    outputs=[
        gr.Markdown(value="", label="Predicted label"),
    ],
    title="ViT retinopathy model",
    description="Diabetic retinopathy model trained on APTOS 2019 dataset; demonstration, not medical advice",
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # keep as-is unless the pinned Gradio version is upgraded.
    allow_flagging="never",
)

# Start the web server when this script is loaded.
iface.launch()