Spaces:
Runtime error
Runtime error
root committed on
Commit ·
eb8bfb7
1
Parent(s): 09e3553
compatible with L40
Browse files- Dockerfile +1 -1
- app.py +14 -14
- levo_inference.py +1 -1
- vllm_hacked/v1/sample/sampler.py +4 -1
Dockerfile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
FROM witszhang/songgeneration_vllm:
|
| 2 |
|
| 3 |
USER root
|
| 4 |
|
|
|
|
| 1 |
+
FROM witszhang/songgeneration_vllm:v2
|
| 2 |
|
| 3 |
USER root
|
| 4 |
|
app.py
CHANGED
|
@@ -183,20 +183,6 @@ lyrics
|
|
| 183 |
)
|
| 184 |
|
| 185 |
with gr.Tabs(elem_id="extra-tabs"):
|
| 186 |
-
with gr.Tab("Genre Select"):
|
| 187 |
-
genre = gr.Radio(
|
| 188 |
-
choices=["Auto", "Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera"],
|
| 189 |
-
label="Genre Select(Optional)",
|
| 190 |
-
value="Auto",
|
| 191 |
-
interactive=True,
|
| 192 |
-
elem_id="single-select-radio"
|
| 193 |
-
)
|
| 194 |
-
with gr.Tab("Audio Prompt"):
|
| 195 |
-
prompt_audio = gr.Audio(
|
| 196 |
-
label="Prompt Audio (Optional)",
|
| 197 |
-
type="filepath",
|
| 198 |
-
elem_id="audio-prompt"
|
| 199 |
-
)
|
| 200 |
with gr.Tab("Text Prompt"):
|
| 201 |
gr.Markdown("For detailed usage, please refer to [here](https://github.com/tencent-ailab/SongGeneration?tab=readme-ov-file#-description-input-format)")
|
| 202 |
description = gr.Textbox(
|
|
@@ -206,6 +192,20 @@ lyrics
|
|
| 206 |
lines=1,
|
| 207 |
max_lines=2
|
| 208 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
with gr.Accordion("Advanced Config", open=False):
|
| 211 |
cfg_coef = gr.Slider(
|
|
|
|
| 183 |
)
|
| 184 |
|
| 185 |
with gr.Tabs(elem_id="extra-tabs"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
with gr.Tab("Text Prompt"):
|
| 187 |
gr.Markdown("For detailed usage, please refer to [here](https://github.com/tencent-ailab/SongGeneration?tab=readme-ov-file#-description-input-format)")
|
| 188 |
description = gr.Textbox(
|
|
|
|
| 192 |
lines=1,
|
| 193 |
max_lines=2
|
| 194 |
)
|
| 195 |
+
with gr.Tab("Audio Prompt"):
|
| 196 |
+
prompt_audio = gr.Audio(
|
| 197 |
+
label="Prompt Audio (Optional)",
|
| 198 |
+
type="filepath",
|
| 199 |
+
elem_id="audio-prompt"
|
| 200 |
+
)
|
| 201 |
+
with gr.Tab("Genre Select"):
|
| 202 |
+
genre = gr.Radio(
|
| 203 |
+
choices=["Auto", "Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera"],
|
| 204 |
+
label="Genre Select(Optional)",
|
| 205 |
+
value="Auto",
|
| 206 |
+
interactive=True,
|
| 207 |
+
elem_id="single-select-radio"
|
| 208 |
+
)
|
| 209 |
|
| 210 |
with gr.Accordion("Advanced Config", open=False):
|
| 211 |
cfg_coef = gr.Slider(
|
levo_inference.py
CHANGED
|
@@ -45,7 +45,7 @@ class LeVoInference(torch.nn.Module):
|
|
| 45 |
model=self.cfg.lm_checkpoint,
|
| 46 |
trust_remote_code=True,
|
| 47 |
tensor_parallel_size=self.cfg.vllm.device_num,
|
| 48 |
-
enforce_eager=
|
| 49 |
dtype="bfloat16",
|
| 50 |
gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
|
| 51 |
tokenizer=None,
|
|
|
|
| 45 |
model=self.cfg.lm_checkpoint,
|
| 46 |
trust_remote_code=True,
|
| 47 |
tensor_parallel_size=self.cfg.vllm.device_num,
|
| 48 |
+
enforce_eager=True,
|
| 49 |
dtype="bfloat16",
|
| 50 |
gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
|
| 51 |
tokenizer=None,
|
vllm_hacked/v1/sample/sampler.py
CHANGED
|
@@ -187,7 +187,10 @@ class Sampler(nn.Module):
|
|
| 187 |
# Avoid division by zero if there are greedy requests.
|
| 188 |
if not all_random:
|
| 189 |
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
| 193 |
return logits.argmax(dim=-1).view(-1)
|
|
|
|
| 187 |
# Avoid division by zero if there are greedy requests.
|
| 188 |
if not all_random:
|
| 189 |
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
| 190 |
+
try:
|
| 191 |
+
return logits.div_(temp.view(-1, 1))
|
| 192 |
+
except:
|
| 193 |
+
return logits.div_(temp.unsqueeze(dim=1))
|
| 194 |
|
| 195 |
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
| 196 |
return logits.argmax(dim=-1).view(-1)
|