legolasyiu committed on
Commit
430aac7
·
verified ·
1 Parent(s): 940de6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -29,7 +29,7 @@ print("Loading STT model...")
29
  stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
30
  stt_model = AutoModelForImageTextToText.from_pretrained(
31
  STT_MODEL_ID,
32
- torch_dtype=DTYPE,
33
  device_map="auto",
34
  )
35
 
@@ -37,8 +37,8 @@ print("Loading TTS model...")
37
  tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
38
  tts_model = AutoModelForCausalLM.from_pretrained(
39
  TTS_MODEL_ID,
40
- torch_dtype=DTYPE,
41
- ).to(DEVICE)
42
 
43
  # -----------------------------
44
  # PIPELINE FUNCTION
@@ -73,7 +73,7 @@ def speech_to_speech(audio_file):
73
  tts_inputs = tts_tokenizer(
74
  transcription,
75
  return_tensors="pt",
76
- ).to(DEVICE)
77
 
78
  with torch.no_grad():
79
  speech = tts_model.generate(**tts_inputs)
 
29
  stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
30
  stt_model = AutoModelForImageTextToText.from_pretrained(
31
  STT_MODEL_ID,
32
+ torch_dtype="auto",
33
  device_map="auto",
34
  )
35
 
 
37
  tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
38
  tts_model = AutoModelForCausalLM.from_pretrained(
39
  TTS_MODEL_ID,
40
+ torch_dtype="auto",
41
+ )
42
 
43
  # -----------------------------
44
  # PIPELINE FUNCTION
 
73
  tts_inputs = tts_tokenizer(
74
  transcription,
75
  return_tensors="pt",
76
+ )
77
 
78
  with torch.no_grad():
79
  speech = tts_model.generate(**tts_inputs)