Spaces:
Runtime error
Runtime error
Working Nari labs code
Browse files
app.py
CHANGED
|
@@ -7,13 +7,15 @@ from dia.model import Dia
|
|
| 7 |
from huggingface_hub import InferenceClient
|
| 8 |
import numpy as np
|
| 9 |
from transformers import set_seed
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# Hardcoded podcast subject
|
| 12 |
PODCAST_SUBJECT = "The future of AI and its impact on society"
|
| 13 |
|
| 14 |
# Initialize the inference client
|
| 15 |
client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
|
| 16 |
-
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="
|
| 17 |
|
| 18 |
# Queue for audio streaming
|
| 19 |
audio_queue = queue.Queue()
|
|
@@ -36,15 +38,9 @@ Now go on, make 5 minutes of podcast.
|
|
| 36 |
|
| 37 |
def split_podcast_into_chunks(podcast_text, chunk_size=3):
|
| 38 |
lines = podcast_text.strip().split("\n")
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
for i in range(0, len(lines), chunk_size):
|
| 42 |
-
chunk = "\n".join(lines[i : i + chunk_size])
|
| 43 |
-
chunks.append(chunk)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def postprocess_audio(output_audio_np, speed_factor: float=0.94):
|
| 48 |
"""Taken from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py"""
|
| 49 |
# Get sample rate from the loaded DAC model
|
| 50 |
output_sr = 44100
|
|
@@ -98,6 +94,7 @@ def process_audio_chunks(podcast_text):
|
|
| 98 |
chunks = split_podcast_into_chunks(podcast_text)
|
| 99 |
sample_rate = 44100 # Modified from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py has 44100
|
| 100 |
for chunk in chunks:
|
|
|
|
| 101 |
if stop_signal.is_set():
|
| 102 |
break
|
| 103 |
set_seed(42)
|
|
@@ -117,26 +114,21 @@ def process_audio_chunks(podcast_text):
|
|
| 117 |
def stream_audio_generator(podcast_text):
|
| 118 |
"""Creates a generator that yields audio chunks for streaming"""
|
| 119 |
stop_signal.clear()
|
|
|
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
# Yield the audio chunk with sample rate
|
| 135 |
-
print(chunk)
|
| 136 |
-
yield chunk
|
| 137 |
-
|
| 138 |
-
except Exception as e:
|
| 139 |
-
print(f"Error in streaming: {e}")
|
| 140 |
|
| 141 |
|
| 142 |
def stop_generation():
|
|
|
|
| 7 |
from huggingface_hub import InferenceClient
|
| 8 |
import numpy as np
|
| 9 |
from transformers import set_seed
|
| 10 |
+
import io, soundfile as sf
|
| 11 |
+
|
| 12 |
|
| 13 |
# Hardcoded podcast subject
|
| 14 |
PODCAST_SUBJECT = "The future of AI and its impact on society"
|
| 15 |
|
| 16 |
# Initialize the inference client
|
| 17 |
client = InferenceClient("meta-llama/Llama-3.3-70B-Instruct", provider="cerebras", token=os.getenv("HF_TOKEN"))
|
| 18 |
+
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
|
| 19 |
|
| 20 |
# Queue for audio streaming
|
| 21 |
audio_queue = queue.Queue()
|
|
|
|
| 38 |
|
| 39 |
def split_podcast_into_chunks(podcast_text, chunk_size=3):
    """Split a podcast transcript into chunks of at most *chunk_size* lines.

    The text is stripped of surrounding whitespace, split on newlines, and
    regrouped into newline-joined strings of ``chunk_size`` lines each; the
    final chunk may be shorter.
    """
    all_lines = podcast_text.strip().split("\n")
    chunks = []
    start = 0
    while start < len(all_lines):
        chunks.append("\n".join(all_lines[start : start + chunk_size]))
        start += chunk_size
    return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
+
def postprocess_audio(output_audio_np, speed_factor: float=0.8):
|
|
|
|
|
|
|
| 44 |
"""Taken from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py"""
|
| 45 |
# Get sample rate from the loaded DAC model
|
| 46 |
output_sr = 44100
|
|
|
|
| 94 |
chunks = split_podcast_into_chunks(podcast_text)
|
| 95 |
sample_rate = 44100 # Modified from https://huggingface.co/spaces/nari-labs/Dia-1.6B/blob/main/app.py has 44100
|
| 96 |
for chunk in chunks:
|
| 97 |
+
print(f"Processing chunk: {chunk}")
|
| 98 |
if stop_signal.is_set():
|
| 99 |
break
|
| 100 |
set_seed(42)
|
|
|
|
| 114 |
def stream_audio_generator(podcast_text):
    """Yield WAV-encoded audio blobs for browser streaming.

    Starts a background producer thread running ``process_audio_chunks``,
    which pushes ``(sample_rate, numpy_array)`` tuples onto ``audio_queue``
    followed by a ``None`` end-of-stream sentinel. Each queued chunk is
    re-encoded into a complete in-memory WAV blob and yielded as bytes.

    Parameters
    ----------
    podcast_text : str
        Full podcast script to synthesize.

    Yields
    ------
    bytes
        One WAV-encoded audio chunk per queue item.
    """
    stop_signal.clear()
    # daemon=True so a stuck producer cannot keep the process alive on exit
    threading.Thread(
        target=process_audio_chunks, args=(podcast_text,), daemon=True
    ).start()

    while True:
        chunk = audio_queue.get()
        if chunk is None:  # producer's end-of-stream sentinel
            break
        sr, data = chunk

        # Encode the numpy array into an in-memory WAV blob.
        # NOTE(review): the /32768.0 rescale assumes int16-range samples —
        # confirm process_audio_chunks actually emits int16-scaled audio.
        buf = io.BytesIO()
        sf.write(buf, data.astype(np.float32) / 32768.0, sr, format="wav")
        # (debug print of the raw byte blob removed: it flooded stdout and
        # slowed streaming; BytesIO.getvalue() needs no prior seek)
        yield buf.getvalue()  # bytes, so the browser can play it
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
|
| 134 |
def stop_generation():
|