alidev2002 committed on
Commit
3b30feb
·
verified ·
1 Parent(s): 3cbfd2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -20
app.py CHANGED
@@ -9,54 +9,98 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  model = OmniVoice.from_pretrained(
11
  "k2-fsa/OmniVoice",
12
- # load_asr=True,
13
  device_map="cuda:0" if device == "cuda" else "cpu",
14
  dtype=torch.float16 if device == "cuda" else torch.float32
15
  )
16
 
17
- # model2 = OmniVoice.from_pretrained(
18
- # "/root/.cache/huggingface/hub/models--k2-fsa--OmniVoice/snapshots/1d8c8a8fd2510535edab4f55aeae328b3e8a456e/",
19
- # load_asr=True,
20
- # asr_model_name="/root/.cache/huggingface/hub/models--openai--whisper-large-v3-turbo/snapshots/41f01f3fe87f28c78e2fbf8b568835947dd65ed9/",
21
- # device_map="cuda:0" if device == "cuda" else "cpu",
22
- # dtype=torch.float16 if device == "cuda" else torch.float32
23
- # )
24
 
25
- def generate(text, ref_audio):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  output_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
28
 
 
 
29
  if ref_audio is not None:
30
  audio = model.generate(
31
  text=text,
 
32
  ref_audio=ref_audio,
33
- num_step=4,
34
- # ref_text optional
35
  )
36
  else:
37
  audio = model.generate(
38
  text=text,
39
- num_step=4,
 
40
  )
41
 
42
  sf.write(output_path, audio[0], 24000)
43
 
44
- # paths = []
45
- # for root, _, files in os.walk("/root/.cache/huggingface"):
46
- # for f in files:
47
- # paths.append(os.path.join(root, f))
48
- # return "\n".join(paths)
49
-
50
  return output_path
51
 
 
52
  demo = gr.Interface(
53
  fn=generate,
54
  inputs=[
55
  gr.Textbox(label="Text"),
56
- gr.Audio(type="filepath", label="Reference Voice (optional)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  ],
 
58
  outputs=gr.Audio(type="filepath"),
59
- title="OmniVoice TTS (Voice Cloning)"
 
 
60
  )
61
 
62
  demo.launch()
 
9
 
10
# Load the OmniVoice TTS checkpoint from the Hugging Face Hub.
# The model is placed on the first GPU in float16 when CUDA is available,
# otherwise it runs on CPU in float32 (float16 is poorly supported on CPU).
model = OmniVoice.from_pretrained(
    "k2-fsa/OmniVoice",
    device_map="cuda:0" if device == "cuda" else "cpu",
    dtype=torch.float16 if device == "cuda" else torch.float32
)
15
 
16
def build_voice_prompt(gender, age, pitch, style):
    """Combine the selected voice attributes into one instruction string.

    Args:
        gender: Optional gender descriptor (e.g. ``"male"``), or ``None``.
        age: Optional age descriptor (e.g. ``"elderly"``), or ``None``.
        pitch: Optional pitch descriptor (e.g. ``"low pitch"``), or ``None``.
        style: Optional speaking-style descriptor, or ``None``.

    Returns:
        A comma-separated description such as ``"female, elderly, low pitch"``,
        or ``None`` when no attribute was selected.
    """
    # Keep only the attributes the user actually picked (None and empty
    # strings are skipped), preserving gender -> age -> pitch -> style order.
    attrs = [attr for attr in (gender, age, pitch, style) if attr]
    return ", ".join(attrs) if attrs else None
33
+
34
+
35
def generate(text, ref_audio, gender, age, pitch, style, num_steps):
    """Synthesize speech for *text* and return the path to a WAV file.

    Args:
        text: Text to synthesize.
        ref_audio: Optional filesystem path of a reference recording used for
            voice cloning; ``None`` when the user supplied none.
        gender: Optional voice-design attribute from the UI dropdown.
        age: Optional voice-design attribute from the UI dropdown.
        pitch: Optional voice-design attribute from the UI dropdown.
        style: Optional voice-design attribute from the UI dropdown.
        num_steps: Number of sampling steps (slider value; may arrive as
            float, hence the int() coercion below).

    Returns:
        Filesystem path of the generated WAV file (written at 24 kHz —
        assumed model output rate, TODO confirm against OmniVoice docs).
    """
    # Reserve a temp file name and close the handle immediately so we do not
    # leak an open file descriptor; delete=False keeps the file on disk so
    # Gradio can serve it after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name

    instruct = build_voice_prompt(gender, age, pitch, style)

    # Build the call once instead of duplicating model.generate() in both
    # branches; ref_audio is only passed when the user supplied a recording.
    gen_kwargs = {
        "text": text,
        "instruct": instruct,
        "num_step": int(num_steps),
    }
    if ref_audio is not None:
        gen_kwargs["ref_audio"] = ref_audio

    audio = model.generate(**gen_kwargs)

    # model.generate returns a batch; write the first waveform.
    sf.write(output_path, audio[0], 24000)

    return output_path
58
 
59
+
60
# Gradio UI. NOTE: the inputs list must match generate()'s positional
# parameters one-to-one: (text, ref_audio, gender, age, pitch, style,
# num_steps).
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Text"),

        gr.Audio(type="filepath", label="Reference Voice (optional)"),

        gr.Dropdown(
            choices=["male", "female"],
            label="Gender",
            value=None,
        ),

        gr.Dropdown(
            choices=["child", "teenager", "young adult", "middle-aged", "elderly"],
            label="Age",
            value=None,
        ),

        gr.Dropdown(
            choices=["very low pitch", "low pitch", "medium pitch", "high pitch"],
            label="Pitch",
            value=None,
        ),

        # BUG FIX: this dropdown was commented out while generate() still
        # takes a `style` parameter, leaving 6 input components for 7
        # function parameters and breaking every submission. Restore it so
        # the arities match.
        gr.Dropdown(
            choices=["normal", "whisper", "calm", "angry"],
            label="Style",
            value=None,
        ),

        gr.Slider(
            minimum=1,
            maximum=32,
            value=4,
            step=1,
            label="num_steps",
        ),
    ],

    outputs=gr.Audio(type="filepath"),

    title="OmniVoice TTS (Voice Design + Cloning)",
    description="Control voice with gender, age, pitch, style + num_steps"
)

demo.launch()