MihaiPopa-1 committed on
Commit
6b1b74f
·
verified ·
1 Parent(s): 2301f40

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ from focal_codec.focal_codec import FocalCodec
4
+ import gradio as gr
5
+ import os # Need this for file path management
6
+ import tempfile # A good way to manage temporary files in Gradio Spaces
7
+
8
# Hugging Face model ID of the codec checkpoint used by this app
# (described by the author as the 0.16 kbps codec).
MODEL_ID = "lucadellalib/focalcodec_12_5hz"

# Load the codec once at import time so every request reuses the same
# instance. Any failure is reported on stdout and leaves ``model`` as None,
# which the request handler checks before doing any work.
model = None
try:
    model = FocalCodec.from_pretrained(MODEL_ID)
    # Run on the GPU whenever one is visible to PyTorch.
    if torch.cuda.is_available():
        model.cuda()
except Exception as exc:
    print(f"Error loading model: {exc}")
    # Reset explicitly: loading may have succeeded before the failure.
    model = None
19
+
20
def encode_decode_focal(audio_input):
    """
    Round-trip input audio through the 160 bps FocalCodec.

    The audio is converted to a mono float32 waveform, resampled to 16 kHz,
    encoded to codes, decoded back to a waveform, and the codes are written
    to a temporary ``.fc`` file (``torch.save`` format) for download.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple as delivered by
            ``gr.Audio(type="numpy")``, or ``None`` when no audio was given.
            ``samples`` is expected to be 1-D (mono) or 2-D laid out as
            ``(samples, channels)``.

    Returns:
        ``((16000, decoded_waveform), fc_file_path)`` on success, or
        ``(None, None)`` when the model is unavailable or there is no audio
        to process (``None`` clears both Gradio output components).
    """
    # Nothing to do if the model failed to load or the user supplied no audio.
    if model is None or audio_input is None:
        return None, None

    sr, wav_numpy = audio_input
    wav = torch.tensor(wav_numpy)
    if wav.numel() == 0:
        return None, None

    # Integer PCM (Gradio's documented numpy format is 16-bit int) must be
    # scaled into [-1, 1); float input is assumed to be normalized already.
    if wav.is_floating_point():
        wav = wav.to(torch.float32)
    else:
        full_scale = float(torch.iinfo(wav.dtype).max) + 1.0
        wav = wav.to(torch.float32) / full_scale

    # Convert stereo to mono by taking the first channel. This has to happen
    # BEFORE the batch dimension is added, while axis 1 is still "channels".
    if wav.dim() > 1:
        wav = wav[:, 0]
    wav = wav.unsqueeze(0)  # -> (1, samples)

    # Resample to 16kHz if necessary (FocalCodec requires 16k input)
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        wav = resampler(wav)

    # Mirror the device choice made when the model was loaded.
    if torch.cuda.is_available():
        wav = wav.cuda()

    # --- Process (Encode and Decode) ---
    # NOTE(review): the encode/decode call shapes are kept exactly as the
    # original author wrote them; confirm against the installed FocalCodec API.
    with torch.no_grad():
        codes, _bandwidth = model.encode(wav)
        decoded_wav = model.decode(codes)

    # --- Save the compressed tokens to a temporary .fc file ---
    # The directory must outlive this call so Gradio can serve the file,
    # hence mkdtemp() rather than a self-cleaning TemporaryDirectory.
    temp_dir = tempfile.mkdtemp()
    fc_file_path = os.path.join(temp_dir, "compressed_tokens.fc")
    # Store CPU tensors so the download also loads on machines without CUDA.
    codes_to_save = codes.cpu() if isinstance(codes, torch.Tensor) else codes
    torch.save(codes_to_save, fc_file_path)

    print(f"Codes saved to {fc_file_path}")

    # Move audio back to CPU and drop singleton dims for the Gradio player.
    decoded_wav_output = decoded_wav.cpu().numpy().squeeze()

    # Return both the audio tuple and the file path string
    return (16000, decoded_wav_output), fc_file_path
63
+
64
+ # --- Gradio Interface ---
65
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost; confirm the Row/Column layout against the live app.
with gr.Blocks() as iface:
    # Header: title (suffixed with the checkpoint name) and a short blurb.
    gr.Markdown(f"## FocalCodec at 160 bps ({MODEL_ID.split('/')[-1]})")
    gr.Markdown("Test the lowest bitrate neural speech codec! This model is optimized ONLY for speech. Upload your audio or record your voice.")

    with gr.Row():
        # Input on the left: microphone recording or file upload.
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="numpy",
            label="Input Audio (Speech Only Recommended)",
        )

        # Outputs stacked on the right: decoded audio and the token file.
        with gr.Column():
            audio_output = gr.Audio(
                type="numpy",
                label="Decoded Output Audio (160 bps)",
            )
            # gr.File provides the download link for the saved codes.
            file_output = gr.File(
                label="Download Compressed Tokens (*.fc file)",
                file_count="single",
                file_types=[".fc"],
            )

    # An explicit button triggers processing and fills both outputs at once.
    process_button = gr.Button("Process Audio", variant="primary")
    process_button.click(
        fn=encode_decode_focal,
        inputs=[audio_input],
        outputs=[audio_output, file_output],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()