root committed on
Commit
eb8bfb7
·
1 Parent(s): 09e3553

compatible with L40

Browse files
Files changed (4) hide show
  1. Dockerfile +1 -1
  2. app.py +14 -14
  3. levo_inference.py +1 -1
  4. vllm_hacked/v1/sample/sampler.py +4 -1
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM witszhang/songgeneration_vllm:v1
2
 
3
  USER root
4
 
 
1
+ FROM witszhang/songgeneration_vllm:v2
2
 
3
  USER root
4
 
app.py CHANGED
@@ -183,20 +183,6 @@ lyrics
183
  )
184
 
185
  with gr.Tabs(elem_id="extra-tabs"):
186
- with gr.Tab("Genre Select"):
187
- genre = gr.Radio(
188
- choices=["Auto", "Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera"],
189
- label="Genre Select(Optional)",
190
- value="Auto",
191
- interactive=True,
192
- elem_id="single-select-radio"
193
- )
194
- with gr.Tab("Audio Prompt"):
195
- prompt_audio = gr.Audio(
196
- label="Prompt Audio (Optional)",
197
- type="filepath",
198
- elem_id="audio-prompt"
199
- )
200
  with gr.Tab("Text Prompt"):
201
  gr.Markdown("For detailed usage, please refer to [here](https://github.com/tencent-ailab/SongGeneration?tab=readme-ov-file#-description-input-format)")
202
  description = gr.Textbox(
@@ -206,6 +192,20 @@ lyrics
206
  lines=1,
207
  max_lines=2
208
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  with gr.Accordion("Advanced Config", open=False):
211
  cfg_coef = gr.Slider(
 
183
  )
184
 
185
  with gr.Tabs(elem_id="extra-tabs"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  with gr.Tab("Text Prompt"):
187
  gr.Markdown("For detailed usage, please refer to [here](https://github.com/tencent-ailab/SongGeneration?tab=readme-ov-file#-description-input-format)")
188
  description = gr.Textbox(
 
192
  lines=1,
193
  max_lines=2
194
  )
195
+ with gr.Tab("Audio Prompt"):
196
+ prompt_audio = gr.Audio(
197
+ label="Prompt Audio (Optional)",
198
+ type="filepath",
199
+ elem_id="audio-prompt"
200
+ )
201
+ with gr.Tab("Genre Select"):
202
+ genre = gr.Radio(
203
+ choices=["Auto", "Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera"],
204
+ label="Genre Select(Optional)",
205
+ value="Auto",
206
+ interactive=True,
207
+ elem_id="single-select-radio"
208
+ )
209
 
210
  with gr.Accordion("Advanced Config", open=False):
211
  cfg_coef = gr.Slider(
levo_inference.py CHANGED
@@ -45,7 +45,7 @@ class LeVoInference(torch.nn.Module):
45
  model=self.cfg.lm_checkpoint,
46
  trust_remote_code=True,
47
  tensor_parallel_size=self.cfg.vllm.device_num,
48
- enforce_eager=False,
49
  dtype="bfloat16",
50
  gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
51
  tokenizer=None,
 
45
  model=self.cfg.lm_checkpoint,
46
  trust_remote_code=True,
47
  tensor_parallel_size=self.cfg.vllm.device_num,
48
+ enforce_eager=True,
49
  dtype="bfloat16",
50
  gpu_memory_utilization=self.cfg.vllm.gpu_memory_utilization,
51
  tokenizer=None,
vllm_hacked/v1/sample/sampler.py CHANGED
@@ -187,7 +187,10 @@ class Sampler(nn.Module):
187
  # Avoid division by zero if there are greedy requests.
188
  if not all_random:
189
  temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
190
- return logits.div_(temp.unsqueeze(dim=1))
 
 
 
191
 
192
  def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
193
  return logits.argmax(dim=-1).view(-1)
 
187
  # Avoid division by zero if there are greedy requests.
188
  if not all_random:
189
  temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
190
+ try:
191
+ return logits.div_(temp.view(-1, 1))
192
+ except:
193
+ return logits.div_(temp.unsqueeze(dim=1))
194
 
195
  def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
196
  return logits.argmax(dim=-1).view(-1)