ABLingss commited on
Commit
0b453ea
·
1 Parent(s): 89fe941
Files changed (1) hide show
  1. app.py +26 -3
app.py CHANGED
@@ -29,9 +29,31 @@ try:
29
  except Exception:
30
  spaces = None
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def _gpu_guard(fn):
33
  if spaces is None:
34
  return fn
 
 
 
 
 
35
  return spaces.GPU(fn)
36
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
37
 
@@ -483,7 +505,6 @@ def load_transcriptor(model_path):
483
  return manager.get_transcriptor()
484
 
485
 
486
- @_gpu_guard
487
  def generate(
488
  lyrics,
489
  tags,
@@ -547,6 +568,7 @@ def generate(
547
  raise gr.Error(f"Generation error: {str(e)}")
548
 
549
 
 
550
  def generate_original(
551
  lyrics,
552
  tags,
@@ -580,6 +602,7 @@ def generate_original(
580
  )
581
 
582
 
 
583
  def generate_accelerated(
584
  lyrics,
585
  tags,
@@ -687,7 +710,6 @@ def generate_music_streaming(
687
  yield 48000, chunk_np
688
 
689
 
690
- @_gpu_guard
691
  def stream_generate(
692
  lyrics,
693
  tags,
@@ -754,6 +776,7 @@ def stream_generate(
754
  raise gr.Error(f"Streaming error: {str(e)}")
755
 
756
 
 
757
  def stream_generate_accelerated(
758
  lyrics,
759
  tags,
@@ -1038,7 +1061,7 @@ def create_ui():
1038
  label="Quantization"
1039
  )
1040
  keep_model_loaded = gr.Checkbox(
1041
- value=True,
1042
  label="Keep Model Loaded"
1043
  )
1044
  offload_mode = gr.Dropdown(
 
29
  except Exception:
30
  spaces = None
31
 
32
+ GPU_MAX_DURATION = int(os.environ.get("GPU_MAX_DURATION", "400"))
33
+
34
+
35
+ def _env_bool(name: str) -> Optional[bool]:
36
+ val = os.environ.get(name)
37
+ if val is None:
38
+ return None
39
+ return val.strip().lower() in ("1", "true", "yes", "y", "on")
40
+
41
+
42
+ _default_keep_model_loaded_env = _env_bool("KEEP_MODEL_LOADED_DEFAULT")
43
+ if _default_keep_model_loaded_env is None:
44
+ DEFAULT_KEEP_MODEL_LOADED = not (spaces is not None and os.environ.get("SPACE_ID"))
45
+ else:
46
+ DEFAULT_KEEP_MODEL_LOADED = _default_keep_model_loaded_env
47
+
48
+
49
  def _gpu_guard(fn):
50
  if spaces is None:
51
  return fn
52
+ if GPU_MAX_DURATION > 0:
53
+ try:
54
+ return spaces.GPU(fn, duration=GPU_MAX_DURATION)
55
+ except TypeError:
56
+ return spaces.GPU(fn)
57
  return spaces.GPU(fn)
58
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
59
 
 
505
  return manager.get_transcriptor()
506
 
507
 
 
508
  def generate(
509
  lyrics,
510
  tags,
 
568
  raise gr.Error(f"Generation error: {str(e)}")
569
 
570
 
571
+ @_gpu_guard
572
  def generate_original(
573
  lyrics,
574
  tags,
 
602
  )
603
 
604
 
605
+ @_gpu_guard
606
  def generate_accelerated(
607
  lyrics,
608
  tags,
 
710
  yield 48000, chunk_np
711
 
712
 
 
713
  def stream_generate(
714
  lyrics,
715
  tags,
 
776
  raise gr.Error(f"Streaming error: {str(e)}")
777
 
778
 
779
+ @_gpu_guard
780
  def stream_generate_accelerated(
781
  lyrics,
782
  tags,
 
1061
  label="Quantization"
1062
  )
1063
  keep_model_loaded = gr.Checkbox(
1064
+ value=DEFAULT_KEEP_MODEL_LOADED,
1065
  label="Keep Model Loaded"
1066
  )
1067
  offload_mode = gr.Dropdown(