Spaces:
Sleeping
Sleeping
Update my_model/tabs/run_inference.py
Browse files
my_model/tabs/run_inference.py
CHANGED
|
@@ -4,6 +4,7 @@ import bitsandbytes
|
|
| 4 |
import accelerate
|
| 5 |
import scipy
|
| 6 |
import copy
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
import torch.nn as nn
|
| 9 |
import pandas as pd
|
|
@@ -32,6 +33,7 @@ class InferenceRunner(StateManager):
|
|
| 32 |
# Display sample images as clickable thumbnails
|
| 33 |
self.col1.write("Choose from sample images:")
|
| 34 |
cols = self.col1.columns(len(self.sample_images))
|
|
|
|
| 35 |
for idx, sample_image_path in enumerate(self.sample_images):
|
| 36 |
with cols[idx]:
|
| 37 |
image = Image.open(sample_image_path)
|
|
@@ -108,7 +110,7 @@ class InferenceRunner(StateManager):
|
|
| 108 |
with st.container():
|
| 109 |
nested_col11, nested_col12 = st.columns([0.5, 0.5])
|
| 110 |
if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
| 111 |
-
|
| 112 |
if st.session_state.button_label == "Load Model":
|
| 113 |
if self.is_model_loaded():
|
| 114 |
free_gpu_resources()
|
|
@@ -121,10 +123,12 @@ class InferenceRunner(StateManager):
|
|
| 121 |
|
| 122 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
| 123 |
force_reload_full_model = True
|
|
|
|
| 124 |
|
| 125 |
if load_fine_tuned_model:
|
| 126 |
free_gpu_resources()
|
| 127 |
self.load_model()
|
|
|
|
| 128 |
st.session_state['loading_in_progress'] = False
|
| 129 |
|
| 130 |
elif fine_tuned_model_already_loaded:
|
|
@@ -139,8 +143,11 @@ class InferenceRunner(StateManager):
|
|
| 139 |
|
| 140 |
elif force_reload_full_model:
|
| 141 |
free_gpu_resources()
|
|
|
|
| 142 |
self.force_reload_model()
|
|
|
|
| 143 |
st.session_state['loading_in_progress'] = False
|
|
|
|
| 144 |
|
| 145 |
elif st.session_state.method == "In-Context Learning (n-shots)":
|
| 146 |
self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
|
|
@@ -148,8 +155,9 @@ class InferenceRunner(StateManager):
|
|
| 148 |
|
| 149 |
|
| 150 |
if self.is_model_loaded():
|
| 151 |
-
st.session_state['
|
| 152 |
free_gpu_resources()
|
|
|
|
| 153 |
self.image_qa_app(self.get_model())
|
| 154 |
st.write(st.session_state['loading_in_progress'])
|
| 155 |
|
|
|
|
| 4 |
import accelerate
|
| 5 |
import scipy
|
| 6 |
import copy
|
| 7 |
+
import time
|
| 8 |
from PIL import Image
|
| 9 |
import torch.nn as nn
|
| 10 |
import pandas as pd
|
|
|
|
| 33 |
# Display sample images as clickable thumbnails
|
| 34 |
self.col1.write("Choose from sample images:")
|
| 35 |
cols = self.col1.columns(len(self.sample_images))
|
| 36 |
+
st.write(st.session_state['loading_in_progress'])
|
| 37 |
for idx, sample_image_path in enumerate(self.sample_images):
|
| 38 |
with cols[idx]:
|
| 39 |
image = Image.open(sample_image_path)
|
|
|
|
| 110 |
with st.container():
|
| 111 |
nested_col11, nested_col12 = st.columns([0.5, 0.5])
|
| 112 |
if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
| 113 |
+
t1=time.time()
|
| 114 |
if st.session_state.button_label == "Load Model":
|
| 115 |
if self.is_model_loaded():
|
| 116 |
free_gpu_resources()
|
|
|
|
| 123 |
|
| 124 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
| 125 |
force_reload_full_model = True
|
| 126 |
+
t1=time.time()
|
| 127 |
|
| 128 |
if load_fine_tuned_model:
|
| 129 |
free_gpu_resources()
|
| 130 |
self.load_model()
|
| 131 |
+
|
| 132 |
st.session_state['loading_in_progress'] = False
|
| 133 |
|
| 134 |
elif fine_tuned_model_already_loaded:
|
|
|
|
| 143 |
|
| 144 |
elif force_reload_full_model:
|
| 145 |
free_gpu_resources()
|
| 146 |
+
|
| 147 |
self.force_reload_model()
|
| 148 |
+
|
| 149 |
st.session_state['loading_in_progress'] = False
|
| 150 |
+
st.session_state['model_loaded'] = True
|
| 151 |
|
| 152 |
elif st.session_state.method == "In-Context Learning (n-shots)":
|
| 153 |
self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
|
|
|
|
| 155 |
|
| 156 |
|
| 157 |
if self.is_model_loaded():
|
| 158 |
+
st.session_state['time_taken_to_load_model'] = time.time()-t1
|
| 159 |
free_gpu_resources()
|
| 160 |
+
st.session_state['loading_in_progress'] = False
|
| 161 |
self.image_qa_app(self.get_model())
|
| 162 |
st.write(st.session_state['loading_in_progress'])
|
| 163 |
|