MihaiPopa-1 commited on
Commit
6fb8d93
·
verified ·
1 Parent(s): 85393f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -37
app.py CHANGED
@@ -1,81 +1,89 @@
1
  import torch
2
  import torchaudio
3
- from focal_codec.focal_codec import FocalCodec
4
  import gradio as gr
5
- import os # Need this for file path management
6
- import tempfile # A good way to manage temporary files in Gradio Spaces
 
7
 
8
- # Define the model ID for the 0.16 kbps codec
9
- MODEL_ID = "lucadellalib/focalcodec_12_5hz"
10
 
11
- # Load the model globally when the app starts
12
  try:
13
- model = FocalCodec.from_pretrained(MODEL_ID)
 
 
 
 
 
 
 
 
14
  if torch.cuda.is_available():
15
- model.cuda()
16
  except Exception as e:
17
- print(f"Error loading model: {e}")
18
- model = None
19
 
20
  def encode_decode_focal(audio_input):
21
  """
22
  Processes input audio through the 160 bps FocalCodec, saves the tokens,
23
  and returns both the decoded WAV and the path to the FC file for download.
24
  """
25
- if model is None:
26
  return (16000, None), None
27
 
28
  sr, wav_numpy = audio_input
29
 
30
- # Convert numpy to torch tensor and ensure float32, mono channel
31
- wav = torch.tensor(wav_numpy, dtype=torch.float32).unsqueeze(0)
32
- if wav.shape > 1: # Convert stereo to mono by taking the first channel
33
- wav = wav[:, 0].unsqueeze(0)
34
-
35
- # Resample to 16kHz if necessary (FocalCodec requires 16k input)
36
- if sr != 16000:
37
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
38
- wav = resampler(wav)
 
 
39
 
40
  if torch.cuda.is_available():
41
- wav = wav.cuda()
42
 
43
  # --- Process (Encode and Decode) ---
44
  with torch.no_grad():
45
- # Encode returns codes and bandwidth
46
- codes, bandwidth = model.encode(wav)
47
- # Decode returns the reconstructed waveform
48
- decoded_wav = model.decode(codes)
 
49
 
50
  # --- Save the compressed tokens to a temporary .fc file ---
51
- # Use tempfile to ensure safe file management in a shared environment
52
  temp_dir = tempfile.mkdtemp()
53
  fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
54
- torch.save(codes, fc_file_path)
 
55
 
56
- print(f"Codes saved to {fc_file_path}")
57
 
58
  # Move audio back to CPU for Gradio output and formatting
59
- decoded_wav_output = decoded_wav.cpu().numpy().squeeze()
 
60
 
61
- # Return both the audio tuple and the file path string
62
- return (16000, decoded_wav_output), fc_file_path
63
 
64
- # --- Gradio Interface ---
65
  with gr.Blocks() as iface:
66
- gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_ID.split('/')[-1]})")
67
- gr.Markdown("Test the lowest bitrate neural speech codec! This model is optimized ONLY for speech. Upload your audio or record your voice.")
68
 
69
  with gr.Row():
70
  audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Input Audio (Speech Only Recommended)")
71
 
72
  with gr.Column():
73
  audio_output = gr.Audio(type="numpy", label="Decoded Output Audio (160 bps)")
74
- # The gr.File component handles the download functionality
75
  file_output = gr.File(label="Download Compressed Tokens (*.fc file)", file_count="single", file_types=[".fc"])
76
 
77
- # Map the function to the components
78
- # We use a button explicitly to manage the output flow better than gr.Interface
79
  process_button = gr.Button("Process Audio", variant="primary")
80
  process_button.click(
81
  fn=encode_decode_focal,
 
1
  import torch
2
  import torchaudio
 
3
  import gradio as gr
4
+ import os
5
+ import tempfile
6
+ import numpy as np
7
 
8
+ # Define the model ID for the 0.16 kbps codec config
9
+ MODEL_CONFIG = "lucadellalib/focalcodec_12_5hz"
10
 
11
+ # Load the model globally using torch.hub
12
  try:
13
+ # torch.hub handles cloning the repo internally
14
+ codec = torch.hub.load(
15
+ repo_or_dir="lucadellalib/focalcodec",
16
+ model="focalcodec",
17
+ config=MODEL_CONFIG,
18
+ force_reload=False # Use cached version after first load
19
+ )
20
+ codec.eval().requires_grad_(False) # Set to evaluation mode
21
+
22
  if torch.cuda.is_available():
23
+ codec.cuda()
24
  except Exception as e:
25
+ print(f"Error loading model via torch.hub: {e}")
26
+ codec = None
27
 
28
  def encode_decode_focal(audio_input):
29
  """
30
  Processes input audio through the 160 bps FocalCodec, saves the tokens,
31
  and returns both the decoded WAV and the path to the FC file for download.
32
  """
33
+ if codec is None:
34
  return (16000, None), None
35
 
36
  sr, wav_numpy = audio_input
37
 
38
+ # Convert numpy to torch tensor and ensure float32
39
+ sig = torch.tensor(wav_numpy, dtype=torch.float32).unsqueeze(0)
40
+
41
+ # Resample input audio to the sample rate required by the codec (16kHz)
42
+ if sr != codec.sample_rate_input:
43
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=codec.sample_rate_input)
44
+ sig = resampler(sig)
45
+
46
+ # Ensure mono channel if needed
47
+ if sig.shape[0] > 1:
48
+ sig = sig[0, :].unsqueeze(0)
49
 
50
  if torch.cuda.is_available():
51
+ sig = sig.cuda()
52
 
53
  # --- Process (Encode and Decode) ---
54
  with torch.no_grad():
55
+ # 1. Encode signal to discrete tokens (the compressed data)
56
+ toks = codec.sig_to_toks(sig)
57
+
58
+ # 2. Decode tokens back into a waveform
59
+ rec_sig = codec.toks_to_sig(toks)
60
 
61
  # --- Save the compressed tokens to a temporary .fc file ---
 
62
  temp_dir = tempfile.mkdtemp()
63
  fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
64
+ # Save the tokens tensor
65
+ torch.save(toks, fc_file_path)
66
 
67
+ print(f"Tokens saved to {fc_file_path}")
68
 
69
  # Move audio back to CPU for Gradio output and formatting
70
+ # Note: Codec output is already at sample_rate_input (16kHz)
71
+ decoded_wav_output = rec_sig.cpu().numpy().squeeze()
72
 
73
+ return (codec.sample_rate_output, decoded_wav_output), fc_file_path
 
74
 
75
+ # --- Gradio Interface (Use the same Blocks interface as before) ---
76
  with gr.Blocks() as iface:
77
+ gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_CONFIG.split('/')[-1]})")
78
+ gr.Markdown("Test the lowest bitrate neural speech codec! Optimized ONLY for speech. Upload your audio or record your voice.")
79
 
80
  with gr.Row():
81
  audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Input Audio (Speech Only Recommended)")
82
 
83
  with gr.Column():
84
  audio_output = gr.Audio(type="numpy", label="Decoded Output Audio (160 bps)")
 
85
  file_output = gr.File(label="Download Compressed Tokens (*.fc file)", file_count="single", file_types=[".fc"])
86
 
 
 
87
  process_button = gr.Button("Process Audio", variant="primary")
88
  process_button.click(
89
  fn=encode_decode_focal,