eyupipler committed on
Commit
0e93a38
·
verified ·
1 Parent(s): 4b299df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -66
app.py CHANGED
@@ -3,53 +3,26 @@ import torch
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
6
- from torchvision.transforms.functional import to_pil_image
7
  from model import load_model
8
  import matplotlib.pyplot as plt
9
  import numpy as np
10
  from thop import profile
11
  import io
12
 
13
def calculate_performance_metrics(model, device, input_size=(1, 3, 224, 224)):
    """Profile a model's size and single-batch inference latency.

    Args:
        model: torch.nn.Module to profile; it is moved to `device` in place.
        device: torch.device on which GPU/host timing is performed.
        input_size: shape of the random dummy input batch.

    Returns:
        dict with 'params_million', 'flops_billion', 'speed_cpu_ms' and
        'speed_gpu_ms' ('speed_gpu_ms' is None when CUDA is unavailable).
    """
    import time

    model.to(device)
    model.eval()
    inputs = torch.randn(input_size).to(device)
    flops, params = profile(model, inputs=(inputs,), verbose=False)
    params_million = params / 1e6
    flops_billion = flops / 1e9

    with torch.no_grad():
        # GPU timing only when CUDA is actually usable: cuda.Event.record()
        # raises on a CPU-only machine, which crashed the original version.
        if device.type == 'cuda':
            _ = model(inputs)  # warm-up so kernel init isn't timed
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            _ = model(inputs)
            end.record()
            torch.cuda.synchronize()
            speed_gpu_ms = start.elapsed_time(end)
        else:
            speed_gpu_ms = None

        # CPU timing: model and input must live on the same device, so move
        # the model to CPU for the measurement and restore it afterwards.
        model.to('cpu')
        inputs_cpu = inputs.to('cpu')
        t0 = time.time()
        _ = model(inputs_cpu)
        t1 = time.time()
        speed_cpu_ms = (t1 - t0) * 1000
        model.to(device)

    return {
        'params_million': round(params_million, 2),
        'flops_billion': round(flops_billion, 2),
        'speed_cpu_ms': round(speed_cpu_ms, 2),
        'speed_gpu_ms': round(speed_gpu_ms, 2) if speed_gpu_ms is not None else None,
    }
44
 
45
# Preprocess transform
def get_transform():
    """Return the 224x224 preprocessing pipeline (resize, tensor conversion,
    ImageNet mean/std normalization) applied to every input image."""
    return transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    ])
52
 
 
53
  class_names = [
54
  'Alzheimer Disease',
55
  'Mild Alzheimer Risk',
@@ -59,42 +32,80 @@ class_names = [
59
  'Parkinson Disease'
60
  ]
61
 
62
- # Gradio predict function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
def predict_and_monitor(version, image):
    """Classify one brain-slice image and report model performance metrics.

    Args:
        version: model variant key ('f', 'c' or 'q') passed to load_model.
        image: PIL image supplied by the Gradio input.

    Returns:
        (per-class probability dict, performance-metrics dict,
         in-memory PNG of the input annotated with the top-1 prediction).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(version, device)

    # preprocess
    img = image.convert("RGB")
    tensor = get_transform()(img).unsqueeze(0).to(device)

    # inference
    with torch.no_grad():
        outputs = model(tensor)
        probs = F.softmax(outputs, dim=1)[0]

    # top-3 for the plot title; full per-class distribution for the JSON output
    topk = torch.topk(probs, k=3)
    pred_items = {class_names[i]: round(float(probs[i]), 4) for i in range(len(class_names))}

    # metrics
    metrics = calculate_performance_metrics(model, device)

    # plot input image with prediction; use the class NAME, not the raw index
    # tensor (the original printed e.g. "Top1: tensor(2)").
    top1_name = class_names[int(topk.indices[0])]
    top1_prob = float(topk.values[0])
    buf = io.BytesIO()
    plt.figure(figsize=(4, 4))
    plt.imshow(img)
    plt.title(f"Top1: {top1_name} ({top1_prob:.4f})")
    plt.axis('off')
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    return pred_items, metrics, buf
 
 
 
 
 
 
 
 
 
 
 
 
88
 
 
89
# Gradio UI: model-version selector + image input on top,
# probability/metric JSON panels and the annotated plot below.
with gr.Blocks() as demo:
    gr.Markdown("# Vbai-DPA 2.1 Risk Classification & Monitoring")
    with gr.Row():
        version = gr.Radio(['f','c','q'], value='c', label="Model Version")
        image_in = gr.Image(type="pil", label="Brain Slice (224x224)")
    with gr.Row():
        preds = gr.JSON(label="Prediction Probabilities")
        stats = gr.JSON(label="Performance Metrics")
    plot = gr.Image(label="Input & Top-1");
    # Run button wires the predict function to all three outputs.
    btn = gr.Button("Run")
    btn.click(fn=predict_and_monitor, inputs=[version, image_in], outputs=[preds, stats, plot])
100
 
 
3
  import torch.nn.functional as F
4
  from torchvision import transforms
5
  from PIL import Image
 
6
  from model import load_model
7
  import matplotlib.pyplot as plt
8
  import numpy as np
9
  from thop import profile
10
  import io
11
 
12
# Device selection: prefer CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cache loaded models keyed by version string ('f'/'c'/'q') so repeated
# predictions don't reload the weights every call.
models_cache = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
# Preprocess transform for 224x224 input: resize, convert to tensor,
# normalize with ImageNet mean/std (standard for pretrained backbones).
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
 
24
 
25
+ # Class names
26
  class_names = [
27
  'Alzheimer Disease',
28
  'Mild Alzheimer Risk',
 
32
  'Parkinson Disease'
33
  ]
34
 
35
# Performance metrics calculation outside predict to not block UI
def calculate_performance(model):
    """Profile `model`: parameter count, FLOPs, and CPU/GPU latency.

    Returns:
        dict with 'params_million', 'flops_billion', 'cpu_ms' and 'gpu_ms'
        ('gpu_ms' is None when CUDA is unavailable).
    """
    import time

    model.eval()
    dummy = torch.randn(1, 3, 224, 224).to(device)
    flops, params = profile(model, inputs=(dummy,), verbose=False)
    params_m = round(params / 1e6, 2)
    flops_b = round(flops / 1e9, 2)

    with torch.no_grad():
        # CPU timing: model and input must be on the same device. The original
        # called model(dummy.cpu()) with the model still on CUDA, which raises
        # a device-mismatch RuntimeError; move the model and restore it after.
        model.to('cpu')
        start = time.time()
        _ = model(dummy.cpu())
        cpu_ms = round((time.time() - start) * 1000, 2)
        model.to(device)

        # GPU timing (only if CUDA is available)
        if device.type == 'cuda':
            _ = model(dummy)  # warm-up so kernel init isn't included
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            _ = model(dummy)
            end_event.record()
            torch.cuda.synchronize()
            gpu_ms = round(start_event.elapsed_time(end_event), 2)
        else:
            gpu_ms = None

    return {'params_million': params_m, 'flops_billion': flops_b, 'cpu_ms': cpu_ms, 'gpu_ms': gpu_ms}
59
+
60
# Prediction function
def predict_and_monitor(version, image):
    """Classify a brain-slice image and report model performance metrics.

    Args:
        version: model variant key ('f', 'c' or 'q') passed to load_model.
        image: PIL image from the Gradio input (may be None).

    Returns:
        (per-class probability dict, performance-metrics dict,
         in-memory PNG of the input annotated with the top-1 prediction).

    Raises:
        gr.Error: when no image was supplied or inference fails.
    """
    try:
        # load or get cached model
        if version not in models_cache:
            models_cache[version] = load_model(version, device)
        model = models_cache[version]

        # preprocess
        if image is None:
            raise gr.Error("Görsel yüklenmedi.")
        img = image.convert("RGB")
        tensor = transform(img).unsqueeze(0).to(device)

        # inference
        with torch.no_grad():
            logits = model(tensor)
            probs = F.softmax(logits, dim=1)[0]

        # prepare outputs
        pred_dict = {class_names[i]: round(float(probs[i]), 4) for i in range(len(class_names))}
        metrics = calculate_performance(model)

        # plot image with top1 label into an in-memory PNG
        top1 = max(pred_dict, key=pred_dict.get)
        buf = io.BytesIO()
        plt.figure(figsize=(3, 3))
        plt.imshow(img)
        plt.title(f"{top1}: {pred_dict[top1]*100:.1f}%")
        plt.axis('off')
        plt.savefig(buf, format='png')
        plt.close()
        buf.seek(0)

        return pred_dict, metrics, buf
    except gr.Error:
        # Let the deliberate validation error reach the user unchanged;
        # the blanket handler below used to re-wrap it as "Tahmin hatası: …".
        raise
    except Exception as e:
        # show exception message
        raise gr.Error(f"Tahmin hatası: {e}")
98
 
99
# Gradio interface: version selector + image input on top,
# probability/metric JSON panels and the annotated plot below.
with gr.Blocks() as demo:
    gr.Markdown("# Vbai-DPA Risk Classification & Monitoring")
    with gr.Row():
        version = gr.Radio(['f','c','q'], value='c', label="Model Version")
        image_in = gr.Image(type="pil", label="Brain Slice (224x224)")
    with gr.Row():
        preds = gr.JSON(label="Prediction Probabilities")
        stats = gr.JSON(label="Performance Metrics")
    plot = gr.Image(label="Input & Top1")
    # Run button wires the predict function to all three outputs.
    btn = gr.Button("Run")
    btn.click(fn=predict_and_monitor, inputs=[version, image_in], outputs=[preds, stats, plot])