# Hugging Face Spaces app (page-scrape residue removed; Space status was "Running")
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from PIL import Image | |
| from model import load_model | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from thop import profile | |
| import io | |
# Run on GPU when available; every model and tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-process cache of loaded models, keyed by version string ('f' | 'c' | 'q'),
# so each checkpoint is read from disk at most once.
models_cache = {}

# Preprocessing pipeline: resize to the network's 224x224 input and normalize
# with the standard ImageNet mean/std (the usual torchvision convention).
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# Class labels, index-aligned with the model's output logits.
class_names = [
    'Alzheimer Disease',
    'Mild Alzheimer Risk',
    'Moderate Alzheimer Risk',
    'Very Mild Alzheimer Risk',
    'No Risk',
    'Parkinson Disease'
]
def calculate_performance(model):
    """Benchmark *model*: parameter count, FLOPs, and CPU/GPU latency.

    Args:
        model: a torch.nn.Module currently residing on the module-level ``device``.

    Returns:
        dict with keys ``params_million``, ``flops_billion``, ``cpu_ms`` and
        ``gpu_ms`` (``gpu_ms`` is ``None`` when CUDA is unavailable).
    """
    import time

    model.eval()
    dummy = torch.randn(1, 3, 224, 224).to(device)

    # thop's profile runs a forward pass with counting hooks; wrap in no_grad
    # so benchmarking never builds an autograd graph.
    with torch.no_grad():
        flops, params = profile(model, inputs=(dummy,), verbose=False)
    params_m = round(params / 1e6, 2)
    flops_b = round(flops / 1e9, 2)

    # CPU latency. The original passed a CPU tensor to a possibly-CUDA model,
    # which raises a device-mismatch error; move the model to CPU for the
    # measurement and restore it afterwards. perf_counter is the monotonic
    # clock intended for interval timing.
    if device.type == 'cuda':
        model.cpu()
    cpu_dummy = dummy.cpu()
    start = time.perf_counter()
    with torch.no_grad():
        _ = model(cpu_dummy)
    cpu_ms = round((time.perf_counter() - start) * 1000, 2)

    if device.type == 'cuda':
        model.to(device)
        # Drain pending kernels so earlier work doesn't pollute the timing.
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        with torch.no_grad():
            _ = model(dummy)
        end_event.record()
        torch.cuda.synchronize()
        gpu_ms = round(start_event.elapsed_time(end_event), 2)
    else:
        gpu_ms = None

    return {'params_million': params_m, 'flops_billion': flops_b, 'cpu_ms': cpu_ms, 'gpu_ms': gpu_ms}
def predict_and_monitor(version, image):
    """Classify an uploaded MRI/fMRI image and benchmark the chosen model.

    Args:
        version: model variant key, one of 'f', 'c', 'q'.
        image: PIL image from the Gradio input, or None when nothing was uploaded.

    Returns:
        tuple of (per-class probability dict, performance-metrics dict,
        PIL image of the annotated prediction preview).

    Raises:
        gr.Error: on missing upload or any inference failure.
    """
    # Validate before touching the model cache so a missing upload fails fast
    # and its message is shown verbatim (not re-wrapped below).
    if image is None:
        raise gr.Error("Görsel yüklenmedi.")
    try:
        # Lazy-load and memoize each model variant.
        if version not in models_cache:
            models_cache[version] = load_model(version, device)
        model = models_cache[version]

        img = image.convert("RGB")
        tensor = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(tensor)
            probs = F.softmax(logits, dim=1)[0]
        pred_dict = {class_names[i]: round(float(probs[i]), 4) for i in range(len(class_names))}

        metrics = calculate_performance(model)
        top1 = max(pred_dict, key=pred_dict.get)

        # Render a small preview annotated with the top-1 label; close the
        # figure in `finally` so it never leaks if savefig raises.
        buf = io.BytesIO()
        fig = plt.figure(figsize=(3, 3))
        try:
            plt.imshow(img)
            plt.title(f"{top1}: {pred_dict[top1]*100:.1f}%")
            plt.axis('off')
            plt.savefig(buf, format='png')
        finally:
            plt.close(fig)
        buf.seek(0)
        buf_image = Image.open(buf)
        return pred_dict, metrics, buf_image
    except gr.Error:
        # Already user-facing; don't wrap it in a second "Prediction Error:".
        raise
    except Exception as e:
        raise gr.Error(f"Prediction Error: {e}") from e
# --- Gradio UI: wires the inputs through predict_and_monitor to the outputs ---
with gr.Blocks() as demo:
    gr.Markdown("Dementia and Parkinson Diagnosis with Vbai-DPA 2.1(f,c,q)")
    with gr.Row():
        # Model variant selector ('c' = Classic is the default) and image upload.
        version = gr.Radio(['f','c','q'], value='c', label="Model Version | f => Fastest, c => Classic, q => Quality")
        image_in = gr.Image(type="pil", label="MRI or fMRI Image")
    with gr.Row():
        # Two JSON panels: per-class probabilities and benchmark metrics.
        preds = gr.JSON(label="Prediction Probabilities")
        stats = gr.JSON(label="Performance Metrics")
    plot = gr.Image(label="Prediction")
    btn = gr.Button("Run")
    # (version, image) -> (probability dict, metrics dict, annotated preview).
    btn.click(fn=predict_and_monitor, inputs=[version, image_in], outputs=[preds, stats, plot])

if __name__ == '__main__':
    demo.launch()