alex committed on
Commit
4127fe1
·
1 Parent(s): 80f7eb2

session id fixes

Browse files
Files changed (1) hide show
  1. app.py +41 -18
app.py CHANGED
@@ -82,13 +82,13 @@ from datetime import timedelta
82
  import torchaudio
83
  import tigersound.look2hear.models
84
 
85
- @spaces.GPU()
86
- def print_ort():
87
 
88
- import onnxruntime as ort
89
- print(ort.get_available_providers())
90
 
91
- print_ort()
92
 
93
  current_dir = os.path.dirname(os.path.abspath(__file__))
94
  snapshot_download("IndexTeam/IndexTTS-2", local_dir=os.path.join(current_dir,"checkpoints"))
@@ -556,21 +556,29 @@ def build_srt(segments: List[Dict], audio_wav: str, out_srt_path: str):
556
  with open(out_srt_path, "w", encoding="utf-8") as f:
557
  f.write(srt.compose(subtitles))
558
 
559
- def translate_video(video_file, duration):
560
- return process_video(video_file, False, duration)
 
 
 
 
561
 
562
- def translate_lipsync_video(video_file, duration):
563
- return process_video(video_file, True, duration)
 
 
 
 
564
 
565
 
566
- def run_example(video_file, allow_lipsync, duration):
567
 
568
  with timer("processed"):
569
- result = process_video(video_file, allow_lipsync, duration)
570
 
571
  return result
572
 
573
- def get_duration(video_file, allow_lipsync, duration):
574
 
575
  if allow_lipsync:
576
  if duration <= 3:
@@ -587,16 +595,16 @@ def get_duration(video_file, allow_lipsync, duration):
587
  return 40
588
 
589
  @spaces.GPU(duration=get_duration)
590
- def process_video(video_file, allow_lipsync, duration):
591
  """
592
  Gradio callback:
593
  - video_file: temp file object/path from Gradio
594
  - returns path to generated SRT file (for download)
595
  """
596
- if video_file is None:
597
- raise gr.Error("Please upload an MP4 video.")
598
 
599
- session_id = uuid.uuid4().hex
 
600
 
601
  output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
602
  os.makedirs(output_dir, exist_ok=True)
@@ -997,9 +1005,23 @@ css = """
997
  }
998
  """
999
 
 
 
 
 
 
 
 
 
 
 
 
1000
 
1001
  with gr.Blocks(css=css) as demo:
1002
 
 
 
 
1003
  with gr.Column(elem_id="col-container"):
1004
  gr.HTML(
1005
  """
@@ -1100,17 +1122,18 @@ with gr.Blocks(css=css) as demo:
1100
 
1101
  translate_btn.click(
1102
  fn=translate_video,
1103
- inputs=[video_input, duration],
1104
  outputs=[video_output, srt_output, vocal_16k_output],
1105
  )
1106
 
1107
  translate_lipsync_btn.click(
1108
  fn=translate_lipsync_video,
1109
- inputs=[video_input, duration],
1110
  outputs=[video_output, srt_output, vocal_16k_output],
1111
  )
1112
 
1113
 
1114
  if __name__ == "__main__":
 
1115
  demo.queue()
1116
  demo.launch()
 
82
  import torchaudio
83
  import tigersound.look2hear.models
84
 
85
+ # @spaces.GPU()
86
+ # def print_ort():
87
 
88
+ # import onnxruntime as ort
89
+ # print(ort.get_available_providers())
90
 
91
+ # print_ort()
92
 
93
  current_dir = os.path.dirname(os.path.abspath(__file__))
94
  snapshot_download("IndexTeam/IndexTTS-2", local_dir=os.path.join(current_dir,"checkpoints"))
 
556
  with open(out_srt_path, "w", encoding="utf-8") as f:
557
  f.write(srt.compose(subtitles))
558
 
559
def translate_video(video_file, duration, session_id=None):
    """Translate an uploaded clip without lip-sync.

    Validates that a clip was uploaded, then delegates to process_video
    with lip-sync disabled. session_id (the Gradio session hash) is
    forwarded so outputs land in the per-session directory; when None,
    process_video generates one itself.
    """
    if video_file is None:
        # Surface a user-facing Gradio error rather than failing downstream.
        raise gr.Error("Please upload a clip.")
    lipsync = False  # this entry point never requests lip-sync
    return process_video(video_file, lipsync, duration, session_id)
565
 
566
def translate_lipsync_video(video_file, duration, session_id=None):
    """Translate an uploaded clip with lip-sync enabled.

    Validates that a clip was uploaded, then delegates to process_video
    with lip-sync requested. session_id (the Gradio session hash) is
    forwarded so outputs land in the per-session directory; when None,
    process_video generates one itself.
    """
    if video_file is None:
        # Surface a user-facing Gradio error rather than failing downstream.
        raise gr.Error("Please upload a clip.")
    lipsync = True  # this entry point always requests lip-sync
    return process_video(video_file, lipsync, duration, session_id)
572
 
573
 
574
def run_example(video_file, allow_lipsync, duration, session_id=None):
    """Run process_video for a gallery example, timing the whole call.

    Thin wrapper used by the examples UI: identical to calling
    process_video directly except the run is wrapped in the "processed"
    timer context for logging.
    """
    with timer("processed"):
        outcome = process_video(video_file, allow_lipsync, duration, session_id)
    return outcome
580
 
581
+ def get_duration(video_file, allow_lipsync, duration, session_id):
582
 
583
  if allow_lipsync:
584
  if duration <= 3:
 
595
  return 40
596
 
597
  @spaces.GPU(duration=get_duration)
598
+ def process_video(video_file, allow_lipsync, duration, session_id = None):
599
  """
600
  Gradio callback:
601
  - video_file: temp file object/path from Gradio
602
  - returns path to generated SRT file (for download)
603
  """
604
+ import onnxruntime as ort
 
605
 
606
+ if session_id == None:
607
+ session_id = uuid.uuid4().hex
608
 
609
  output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
610
  os.makedirs(output_dir, exist_ok=True)
 
1005
  }
1006
  """
1007
 
1008
def cleanup(request: gr.Request):
    """Delete a client's per-session scratch directory on disconnect.

    Registered via demo.unload; removes
    $PROCESSED_RESULTS/<session_hash> best-effort (errors ignored, so a
    missing or busy directory never raises).
    """
    sid = request.session_hash
    if not sid:
        # No session hash — nothing was ever written for this client.
        return
    print(f"{sid} left")
    session_dir = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
    shutil.rmtree(session_dir, ignore_errors=True)
1015
+
1016
def start_session(request: gr.Request):
    """Return this client's Gradio session hash.

    Wired to demo.load so the hash is captured into session_state and
    later threaded through the processing callbacks as session_id.
    """
    sid = request.session_hash
    return sid
1019
 
1020
  with gr.Blocks(css=css) as demo:
1021
 
1022
+ session_state = gr.State()
1023
+ demo.load(start_session, outputs=[session_state])
1024
+
1025
  with gr.Column(elem_id="col-container"):
1026
  gr.HTML(
1027
  """
 
1122
 
1123
  translate_btn.click(
1124
  fn=translate_video,
1125
+ inputs=[video_input, duration, session_state],
1126
  outputs=[video_output, srt_output, vocal_16k_output],
1127
  )
1128
 
1129
  translate_lipsync_btn.click(
1130
  fn=translate_lipsync_video,
1131
+ inputs=[video_input, duration, session_state],
1132
  outputs=[video_output, srt_output, vocal_16k_output],
1133
  )
1134
 
1135
 
1136
  if __name__ == "__main__":
1137
+ demo.unload(cleanup)
1138
  demo.queue()
1139
  demo.launch()