eyupipler commited on
Commit
fee9d86
·
verified ·
1 Parent(s): a0f945e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -47
app.py CHANGED
@@ -1,16 +1,55 @@
1
  import gradio as gr
2
  import torch
 
3
  from torchvision import transforms
4
  from PIL import Image
 
5
  from model import load_model
 
 
 
 
6
 
7
- # Device selection
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # Model cache
11
- models_cache = {}
 
 
 
 
 
12
 
13
- # Common class names
14
  class_names = [
15
  'Alzheimer Disease',
16
  'Mild Alzheimer Risk',
@@ -20,50 +59,44 @@ class_names = [
20
  'Parkinson Disease'
21
  ]
22
 
23
- # Image preprocessing
24
- transform = transforms.Compose([
25
- transforms.Resize((448, 448)),
26
- transforms.ToTensor(),
27
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
28
- std=[0.229, 0.224, 0.225])
29
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- def predict(version: str, image: Image.Image):
32
- """
33
- Predict Alzheimer/Parkinson risk using selected version (f, c, q) on a 2D brain slice.
34
- """
35
- try:
36
- # Load model if not cached
37
- if version not in models_cache:
38
- models_cache[version] = load_model(version, device)
39
- model = models_cache[version]
40
-
41
- # Convert and preprocess image
42
- img = image.convert("RGB")
43
- tensor = transform(img).unsqueeze(0).to(device)
44
-
45
- # Inference
46
- with torch.no_grad():
47
- outputs = model(tensor)
48
- probs = torch.nn.functional.softmax(outputs, dim=1)[0]
49
-
50
- # Return full probability mapping
51
- return {class_names[i]: float(probs[i]) for i in range(len(class_names))}
52
- except Exception as e:
53
- # Raise a Gradio error to display in UI
54
- raise gr.Error(f"Inference error: {str(e)}")
55
-
56
- # Build Gradio interface
57
  with gr.Blocks() as demo:
58
- gr.Markdown("## 🧠 Vbai-DPA 2.1 Alzheimer & Parkinson Risk Classification")
59
- gr.Markdown("Seçmek istediğin modeli (f, c veya q) seç ve bir 2D beyin dilimi yükle.")
60
-
 
61
  with gr.Row():
62
- version_selector = gr.Radio(choices=['f', 'c', 'q'], value='c', label="Model Version")
63
- image_input = gr.Image(type="pil", label="Brain Slice Image")
64
- output = gr.JSON(label="Olasılıklar (JSON)") # Use JSON output to handle errors and full mapping
65
- predict_btn = gr.Button("Tahmin Et")
66
- predict_btn.click(fn=predict, inputs=[version_selector, image_input], outputs=output)
67
 
68
- if __name__ == "__main__":
69
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
+ from torchvision.transforms.functional import to_pil_image
7
  from model import load_model
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from thop import profile
11
+ import io
12
 
13
+ def calculate_performance_metrics(model, device, input_size=(1,3,224,224)):
14
+ model.to(device)
15
+ inputs = torch.randn(input_size).to(device)
16
+ flops, params = profile(model, inputs=(inputs,), verbose=False)
17
+ params_million = params / 1e6
18
+ flops_billion = flops / 1e9
19
+ # timing
20
+ with torch.no_grad():
21
+ start = torch.cuda.Event(enable_timing=True)
22
+ end = torch.cuda.Event(enable_timing=True)
23
+ start.record()
24
+ _ = model(inputs)
25
+ end.record()
26
+ torch.cuda.synchronize()
27
+ speed_gpu_ms = start.elapsed_time(end)
28
+ # CPU timing
29
+ inputs_cpu = inputs.to('cpu')
30
+ start_c = torch.cuda.Event(enable_timing=True)
31
+ end_c = torch.cuda.Event(enable_timing=True)
32
+ # use time.time as fallback for CPU
33
+ import time
34
+ t0 = time.time()
35
+ _ = model(inputs_cpu)
36
+ t1 = time.time()
37
+ speed_cpu_ms = (t1 - t0) * 1000
38
+ return {
39
+ 'params_million': round(params_million,2),
40
+ 'flops_billion': round(flops_billion,2),
41
+ 'speed_cpu_ms': round(speed_cpu_ms,2),
42
+ 'speed_gpu_ms': round(speed_gpu_ms,2)
43
+ }
44
 
45
+ # Preprocess transform
46
+ def get_transform():
47
+ return transforms.Compose([
48
+ transforms.Resize((224,224)),
49
+ transforms.ToTensor(),
50
+ transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
51
+ ])
52
 
 
53
  class_names = [
54
  'Alzheimer Disease',
55
  'Mild Alzheimer Risk',
 
59
  'Parkinson Disease'
60
  ]
61
 
62
+ # Gradio predict function
63
+ def predict_and_monitor(version, image):
64
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+ model = load_model(version, device)
66
+ # preprocess
67
+ img = image.convert("RGB")
68
+ tensor = get_transform()(img).unsqueeze(0).to(device)
69
+ # inference
70
+ with torch.no_grad():
71
+ outputs = model(tensor)
72
+ probs = F.softmax(outputs, dim=1)[0]
73
+ # prepare top3
74
+ topk = torch.topk(probs, k=3)
75
+ pred_items = {class_names[i]: round(float(probs[i]),4) for i in range(len(class_names))}
76
+ # metrics
77
+ metrics = calculate_performance_metrics(model, device)
78
+ # plot input image with prediction
79
+ buf = io.BytesIO()
80
+ plt.figure(figsize=(4,4))
81
+ plt.imshow(img)
82
+ plt.title(f"Top1: {topk.indices[0]} ({topk.values[0]:.4f})")
83
+ plt.axis('off')
84
+ plt.savefig(buf, format='png')
85
+ plt.close()
86
+ buf.seek(0)
87
+ return pred_items, metrics, buf
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  with gr.Blocks() as demo:
90
+ gr.Markdown("# Vbai-DPA 2.1 Risk Classification & Monitoring")
91
+ with gr.Row():
92
+ version = gr.Radio(['f','c','q'], value='c', label="Model Version")
93
+ image_in = gr.Image(type="pil", label="Brain Slice (224x224)")
94
  with gr.Row():
95
+ preds = gr.JSON(label="Prediction Probabilities")
96
+ stats = gr.JSON(label="Performance Metrics")
97
+ plot = gr.Image(label="Input & Top-1");
98
+ btn = gr.Button("Run")
99
+ btn.click(fn=predict_and_monitor, inputs=[version, image_in], outputs=[preds, stats, plot])
100
 
101
+ if __name__ == '__main__':
102
  demo.launch()