apehex commited on
Commit
75344c6
·
1 Parent(s): 0958d5d

Warn the user when the calculations are aborted.

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -97,16 +97,22 @@ def compute_indices(
97
 
98
  # LOGITS #######################################################################
99
 
100
- @spaces.GPU(duration=60)
101
  def compute_logits(
102
  indices_arr: object,
103
  ) -> object:
104
- # load the model and tokenizer inside the GPU wrapper
 
105
  __model = fetch_model()
106
- # fill all the arguments that cannot be pickled
107
- return _app.update_logits_state(
108
- indices_arr=indices_arr,
109
- model_obj=__model)
 
 
 
 
 
110
 
111
  # MAIN #########################################################################
112
 
 
97
 
98
  # LOGITS #######################################################################
99
 
100
+ @spaces.GPU(duration=30)
101
  def compute_logits(
102
  indices_arr: object,
103
  ) -> object:
104
+ __logits = None
105
+ # load the model inside the GPU wrapper (not before)
106
  __model = fetch_model()
107
+ # the allocation might expire before the calculations are finished
108
+ try:
109
+ __logits = _app.update_logits_state(
110
+ indices_arr=indices_arr,
111
+ model_obj=__model)
112
+ except:
113
+ gradio.Warning(title='Warning', message='Calculations aborted because the GPU allocation expired.', duration=4)
114
+ # tensor or None
115
+ return __logits
116
 
117
  # MAIN #########################################################################
118