alrichardbollans commited on
Commit
aef7d7f
·
1 Parent(s): 5f387da

Add styling and basic detectron functionality

Browse files
Files changed (5) hide show
  1. app.py +127 -166
  2. python_utils/__init__.py +1 -0
  3. python_utils/get_model.py +83 -0
  4. requirements.txt +0 -3
  5. styles.css +117 -7
app.py CHANGED
@@ -1,173 +1,134 @@
1
- import detectron2
2
- import torch
3
-
4
- # Check this in logs
5
- try:
6
- print(f"Is CUDA available: {torch.cuda.is_available()}")
7
- # True
8
- print(f"CUDA device: {torch.cuda.get_device_name(nn.cuda.current_device())}")
9
- except:
10
- print('Couldnt find CUDA device')
11
-
12
- import faicons as fa
13
- import plotly.express as px
14
-
 
 
 
 
 
 
 
15
  # Load data and compute static values
16
- from shared import app_dir, tips
17
- from shinywidgets import render_plotly
18
-
19
- from shiny import reactive, render
20
- from shiny.express import input, ui
21
-
22
- bill_rng = (min(tips.total_bill), max(tips.total_bill))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # Add page title and sidebar
25
- ui.page_opts(title="Restaurant tipping", fillable=True)
26
 
27
- with ui.sidebar(open="desktop"):
28
- ui.input_slider(
29
- "total_bill",
30
- "Bill amount",
31
- min=bill_rng[0],
32
- max=bill_rng[1],
33
- value=bill_rng,
34
- pre="$",
35
- )
36
- ui.input_checkbox_group(
37
- "time",
38
- "Food service",
39
- ["Lunch", "Dinner"],
40
- selected=["Lunch", "Dinner"],
41
- inline=True,
42
- )
43
- ui.input_action_button("reset", "Reset filter")
44
-
45
- # Add main content
46
- ICONS = {
47
- "user": fa.icon_svg("user", "regular"),
48
- "wallet": fa.icon_svg("wallet"),
49
- "currency-dollar": fa.icon_svg("dollar-sign"),
50
- "ellipsis": fa.icon_svg("ellipsis"),
51
- }
52
-
53
- with ui.layout_columns(fill=False):
54
- with ui.value_box(showcase=ICONS["user"]):
55
- "Total tippers"
56
-
57
- @render.express
58
- def total_tippers():
59
- tips_data().shape[0]
60
-
61
- with ui.value_box(showcase=ICONS["wallet"]):
62
- "Average tip"
63
-
64
- @render.express
65
- def average_tip():
66
- d = tips_data()
67
- if d.shape[0] > 0:
68
- perc = d.tip / d.total_bill
69
- f"{perc.mean():.1%}"
70
-
71
- with ui.value_box(showcase=ICONS["currency-dollar"]):
72
- "Average bill"
73
-
74
- @render.express
75
- def average_bill():
76
- d = tips_data()
77
- if d.shape[0] > 0:
78
- bill = d.total_bill.mean()
79
- f"${bill:.2f}"
80
-
81
-
82
- with ui.layout_columns(col_widths=[6, 6, 12]):
83
- with ui.card(full_screen=True):
84
- ui.card_header("Tips data")
85
-
86
- @render.data_frame
87
- def table():
88
- return render.DataGrid(tips_data())
89
-
90
- with ui.card(full_screen=True):
91
- with ui.card_header(class_="d-flex justify-content-between align-items-center"):
92
- "Total bill vs tip"
93
- with ui.popover(title="Add a color variable", placement="top"):
94
- ICONS["ellipsis"]
95
- ui.input_radio_buttons(
96
- "scatter_color",
97
- None,
98
- ["none", "sex", "smoker", "day", "time"],
99
- inline=True,
100
- )
101
-
102
- @render_plotly
103
- def scatterplot():
104
- color = input.scatter_color()
105
- return px.scatter(
106
- tips_data(),
107
- x="total_bill",
108
- y="tip",
109
- color=None if color == "none" else color,
110
- trendline="lowess",
111
- )
112
-
113
- with ui.card(full_screen=True):
114
- with ui.card_header(class_="d-flex justify-content-between align-items-center"):
115
- "Tip percentages"
116
- with ui.popover(title="Add a color variable"):
117
- ICONS["ellipsis"]
118
- ui.input_radio_buttons(
119
- "tip_perc_y",
120
- "Split by:",
121
- ["sex", "smoker", "day", "time"],
122
- selected="day",
123
- inline=True,
124
- )
125
-
126
- @render_plotly
127
- def tip_perc():
128
- from ridgeplot import ridgeplot
129
-
130
- dat = tips_data()
131
- dat["percent"] = dat.tip / dat.total_bill
132
- yvar = input.tip_perc_y()
133
- uvals = dat[yvar].unique()
134
-
135
- samples = [[dat.percent[dat[yvar] == val]] for val in uvals]
136
-
137
- plt = ridgeplot(
138
- samples=samples,
139
- labels=uvals,
140
- bandwidth=0.01,
141
- colorscale="viridis",
142
- colormode="row-index",
143
- )
144
-
145
- plt.update_layout(
146
- legend=dict(
147
- orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5
148
- )
149
- )
150
-
151
- return plt
152
-
153
-
154
- ui.include_css(app_dir / "styles.css")
155
 
156
  # --------------------------------------------------------
157
  # Reactive calculations and effects
158
  # --------------------------------------------------------
159
-
160
-
161
- @reactive.calc
162
- def tips_data():
163
- bill = input.total_bill()
164
- idx1 = tips.total_bill.between(bill[0], bill[1])
165
- idx2 = tips.time.isin(input.time())
166
- return tips[idx1 & idx2]
167
-
168
-
169
- @reactive.effect
170
- @reactive.event(input.reset)
171
- def _():
172
- ui.update_slider("total_bill", value=bill_rng)
173
- ui.update_checkbox_group("time", selected=["Lunch", "Dinner"])
 
1
+ # import detectron2
2
+ # import torch
3
+ #
4
+ # # Check this in logs
5
+ # try:
6
+ # print(f"Is CUDA available: {torch.cuda.is_available()}")
7
+ # # True
8
+ # print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
9
+ # except:
10
+ # print('Couldnt find CUDA device')
11
+
12
+ import base64
13
+ import tempfile
14
+ import cv2
15
+ from io import BytesIO
16
+
17
+ import pandas as pd
18
+ from PIL import Image
19
+ from shiny import App, ui, render, reactive, Session
20
+
21
+ from python_utils import load_model
22
  # Load data and compute static values
23
+ from shared import app_dir
24
+
25
+ # Load the prediction model
26
+ predictor = load_model()
27
+ app_ui = ui.page_fluid(
28
+ ui.include_css("styles.css"),
29
+ ui.panel_title(ui.div("Orchid TZ Viability Analyzer", class_="navbar-title")),
30
+ ui.div(
31
+ ui.download_button("download", "Download Results", class_="btn-primary"),
32
+ style="position: absolute; top: 10px; right: 10px;"
33
+ ),
34
+ ui.layout_sidebar(
35
+ ui.sidebar(
36
+ ui.input_file("upload", "Upload Images",
37
+ multiple=True,
38
+ accept=[".png", ".jpg", ".jpeg"]),
39
+ ui.input_action_button("analyze", "Analyze", class_="btn-success"),
40
+ width =300
41
+ ),
42
+ ui.output_ui("results_container"),
43
+ border=False,
44
+ border_radius=False
45
+ )
46
+ )
47
+
48
+
49
+
50
+ def server(input, output, session: Session):
51
+ analysis_results = reactive.Value([])
52
+
53
+ @reactive.Effect
54
+ @reactive.event(input.analyze)
55
+ async def process_images():
56
+ files = input.upload()
57
+ if not files:
58
+ return
59
+
60
+ results = []
61
+ with tempfile.TemporaryDirectory() as temp_dir:
62
+ for idx, file in enumerate(files):
63
+ # Read image using OpenCV
64
+ im = cv2.imread(file["datapath"])
65
+
66
+ # Convert BGR to RGB for display
67
+ im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
68
+ pil_img = Image.fromarray(im_rgb)
69
+
70
+ # Convert to base64 for HTML display
71
+ buffered = BytesIO()
72
+ pil_img.save(buffered, format="PNG")
73
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
74
+
75
+ # Run prediction with original BGR image
76
+ prediction = predictor(im)
77
+
78
+ results.append({
79
+ "filename": file["name"],
80
+ "image": img_base64,
81
+ **prediction
82
+ })
83
+
84
+ # Update reactive value
85
+ analysis_results.set(results)
86
+
87
+ @output
88
+ @render.ui
89
+ def results_container():
90
+ results = analysis_results.get()
91
+ if not results:
92
+ return ui.div("No results yet. Upload images and click 'Analyze'.",
93
+ class_="text-muted")
94
+
95
+ return ui.div(
96
+ [ui.div(
97
+ ui.row(
98
+ ui.column(4, ui.img(src=f"data:image/png;base64,{r['image']}")),
99
+ ui.column(4, ui.img(src=f"data:image/png;base64,{r['image']}")),
100
+ ),
101
+ ui.h5(r['filename'], style="margin-top: 15px;"),
102
+ ui.div(
103
+ ui.span(f"Viable = {r.get('viable', '?')}"),
104
+ ui.span(f"Nonviable = {r.get('nonviable', '?')}", style="margin: 0 15px;"),
105
+ ui.span(f"Empty = {r.get('empty', '?')}"),
106
+ class_="results-text"
107
+ ),
108
+ class_="card p-3"
109
+ ) for r in results]
110
+ )
111
+
112
+ @session.download()
113
+ def download():
114
+ results = analysis_results.get()
115
+ df = pd.DataFrame([{
116
+ "Filename": r["filename"],
117
+ "Viable": r.get("viable", ""),
118
+ "Nonviable": r.get("nonviable", ""),
119
+ "Empty": r.get("empty", "")
120
+ } for r in results])
121
+
122
+ # Create in-memory CSV file
123
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
124
+ df.to_csv(tmp.name, index=False)
125
+ return tmp.name
126
+
127
+
128
+ app = App(app_ui, server)
129
 
 
 
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  # --------------------------------------------------------
133
  # Reactive calculations and effects
134
  # --------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python_utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .get_model import *
python_utils/get_model.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def get_set_up():
2
+ import torch
3
+ TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
4
+ CUDA_VERSION = torch.__version__.split("+")[-1]
5
+ print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
6
+ print(f'GPU available: {torch.cuda.is_available()}')
7
+ print(torch.cuda.get_device_capability())
8
+
9
+ # print("detectron2:", detectron2.__version__)
10
+
11
+
12
+ def load_model():
13
+ # def predictor(img):
14
+ # return {}
15
+ # return predictor
16
+ # import some common detectron2 utilities
17
+ import torch
18
+ from detectron2 import model_zoo
19
+ from detectron2.engine import DefaultPredictor
20
+ from detectron2.config import get_cfg
21
+ from detectron2.data.datasets import register_coco_instances
22
+
23
+ import os
24
+ import numpy as np
25
+
26
+ ## define relevant parameters
27
+ cfg = get_cfg()
28
+ cfg.merge_from_file(
29
+ model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
30
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4 # should be 3 after renaming 'Seed' class
31
+ if not torch.cuda.is_available():
32
+ cfg.MODEL.DEVICE = "cpu"
33
+ else:
34
+ cfg.MODEL.DEVICE = 'cuda'
35
+
36
+ register_coco_instances(
37
+ 'seeds', {"thing_classes": ['Seed', 'Viable', 'Non-Viable', 'Empty'],
38
+ "thing_colors": [(0, 0, 0), (0, 255, 0), (255, 0, 0), (0, 0, 255)]},
39
+ 'dataset1/train/annotations_train.json', 'dataset1/train/')
40
+ cfg.DATASETS.TRAIN = ("seeds",)
41
+
42
+ mean = [0.5, 0.2, 0.1]
43
+ std = [0.5, 0.1, 0.1] # mean_and_std("dataset1/Part1_COCO/images/train/")
44
+ cfg.MODEL.PIXEL_MEAN = np.array(mean, dtype=float).tolist()
45
+ cfg.MODEL.PIXEL_STD = np.array(std, dtype=float).tolist()
46
+
47
+ cfg.MODEL.ROI_HEADS.NAME = "CascadeROIHeads"
48
+ cfg.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = True
49
+ cfg.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = True
50
+ cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512, 1024]]
51
+ cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0]]
52
+ cfg.MODEL.FPN.IN_FEATURES = ["res2", "res3", "res4", "res5"]
53
+ cfg.MODEL.RPN.IN_FEATURES = ["p2", "p3", "p4", "p5"]
54
+ cfg.MODEL.ROI_HEADS.IN_FEATURES = ["p2", "p3", "p4", "p5"]
55
+ cfg.MODEL.FPN.NORM = "GN"
56
+ cfg.MODEL.ROI_BOX_HEAD.NORM = "GN"
57
+ cfg.MODEL.ROI_MASK_HEAD.NORM = "GN"
58
+ cfg.MODEL.RESNETS.NORM = "GN"
59
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
60
+ cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"
61
+
62
+ cfg.MODEL.RPN.NMS_THRESH = 0.3
63
+ cfg.MODEL.RPN.PRE_NMS_TOPK_TEST = 12000
64
+ cfg.MODEL.RPN.POST_NMS_TOPK_TEST = 8000
65
+ # threshold for confidence
66
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
67
+ # removing overlapping bounding boxes of the same class
68
+ cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5
69
+ # max number of instances per image
70
+ cfg.TEST.DETECTIONS_PER_IMAGE = 1200
71
+
72
+ ## Load trained model
73
+ ## Local files
74
+ # cfg.OUTPUT_DIR = "../YOLO/outputs/output20"
75
+ # cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
76
+ ## Or hugging face model
77
+ cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
78
+ predictor = DefaultPredictor(cfg)
79
+ return predictor
80
+
81
+ if __name__ == '__main__':
82
+ # get_set_up()
83
+ load_model()
requirements.txt CHANGED
@@ -1,9 +1,6 @@
1
- faicons
2
  shiny
3
  shinywidgets
4
- plotly
5
  pandas
6
- ridgeplot
7
  opencv-python-headless
8
  pyyaml==5.1
9
  torch
 
 
1
  shiny
2
  shinywidgets
 
3
  pandas
 
4
  opencv-python-headless
5
  pyyaml==5.1
6
  torch
styles.css CHANGED
@@ -1,12 +1,122 @@
1
- :root {
2
- --bslib-sidebar-main-bg: #f8f8f8;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  }
4
 
5
- .popover {
6
- --bs-popover-header-bg: #222;
7
- --bs-popover-header-color: #fff;
8
  }
9
 
10
- .popover .btn-close {
11
- filter: var(--bs-btn-close-white-filter);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  }
 
1
+ /* www/styles.css */
2
+ /* Modern sleek theme with dark mode elements */
3
+
4
+ body {
5
+ background-color: #f8f9fa;
6
+ font-family: 'Segoe UI', system-ui, -apple-system, sans-serif;
7
+ }
8
+
9
+ .container-fluid {
10
+ padding: 20px;
11
+ max-width: 1400px;
12
+ margin: 0 auto;
13
+ }
14
+
15
+ /* Header styling */
16
+ .navbar-title {
17
+ color: #2c3e50 !important;
18
+ font-weight: 700;
19
+ font-size: 1.8rem;
20
+ padding: 15px 0;
21
+ }
22
+
23
+ /* Sidebar styling */
24
+ .card.shiny-input-container {
25
+ background-color: #ffffff;
26
+ border-radius: 12px;
27
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
28
+ padding: 20px;
29
+ margin-bottom: 20px;
30
+ }
31
+
32
+ /* Upload button styling */
33
+ .btn-file {
34
+ background-color: #4a90e2;
35
+ color: white !important;
36
+ border-radius: 8px;
37
+ padding: 10px 20px;
38
+ transition: all 0.3s ease;
39
+ }
40
+
41
+ .btn-file:hover {
42
+ background-color: #357abd;
43
+ transform: translateY(-1px);
44
+ }
45
+
46
+ /* Analyze button styling */
47
+ .btn-success {
48
+ background-color: #27ae60 !important;
49
+ border: none;
50
+ border-radius: 8px;
51
+ padding: 12px 25px;
52
+ font-weight: 600;
53
+ transition: all 0.3s ease;
54
  }
55
 
56
+ .btn-success:hover {
57
+ background-color: #219653 !important;
58
+ transform: translateY(-1px);
59
  }
60
 
61
+ /* Image cards styling */
62
+ .card {
63
+ background: white;
64
+ border: none;
65
+ border-radius: 15px;
66
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.08);
67
+ margin-bottom: 25px;
68
+ overflow: hidden;
69
+ transition: transform 0.2s ease;
70
+ }
71
+
72
+ .card:hover {
73
+ transform: translateY(-3px);
74
+ }
75
+
76
+ /* Image display styling */
77
+ img {
78
+ border-radius: 10px;
79
+ object-fit: cover;
80
+ max-height: 300px;
81
+ width: 100%;
82
+ margin: 10px 0;
83
+ box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
84
+ }
85
+
86
+ /* Results text styling */
87
+ .results-text {
88
+ color: #2c3e50;
89
+ font-family: 'Courier New', monospace;
90
+ font-size: 1.1rem;
91
+ margin: 15px 0;
92
+ padding: 12px;
93
+ background-color: #f8f9fa;
94
+ border-radius: 6px;
95
+ border-left: 4px solid #4a90e2;
96
+ }
97
+
98
+ /* Download button styling */
99
+ .btn-primary {
100
+ background-color: #2ecc71 !important;
101
+ border: none;
102
+ border-radius: 8px;
103
+ padding: 10px 25px;
104
+ font-weight: 600;
105
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
106
+ }
107
+
108
+ .btn-primary:hover {
109
+ background-color: #27ae60 !important;
110
+ }
111
+
112
+ /* Responsive design */
113
+ @media (max-width: 768px) {
114
+ .col-md-4 {
115
+ flex: 0 0 100%;
116
+ max-width: 100%;
117
+ }
118
+
119
+ img {
120
+ max-height: 200px;
121
+ }
122
  }