|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import librosa |
|
|
import soundfile as sf |
|
|
from transformers import AutoTokenizer |
|
|
import tempfile |
|
|
import requests |
|
|
import spaces |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextToMusicGenerator(nn.Module):
    """Ricco model architecture (discriminator kept for checkpoint compatibility).

    GAN-style text-to-music generator: a frozen BERT encoder projects the
    prompt into a latent vector, which is concatenated with noise, expanded
    by an MLP, and decoded by transposed convolutions into a mel-spectrogram
    of shape (batch, 1, 128, 624) bounded to [-1, 1] by the final Tanh.
    """

    def __init__(self, config):
        """Build all sub-networks.

        Args:
            config: dict with at least 'text_embed_dim', 'latent_dim',
                'sample_rate', 'hop_length' and 'n_fft'
                (see the config built in load_ricco_model).
        """
        super().__init__()
        self.config = config

        # Local import keeps `transformers` off the module import path
        # until a model instance is actually built.
        from transformers import AutoModel
        self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')

        # Maps the BERT sentence embedding (text_embed_dim=768) into the
        # latent space used by the generator.
        self.text_projection = nn.Sequential(
            nn.Linear(config['text_embed_dim'], config['latent_dim']),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(config['latent_dim'], config['latent_dim'])
        )

        # MLP that expands [text_latent ‖ noise] (2 * latent_dim) into a
        # flattened 64x8x39 feature map — the seed for the conv decoder.
        self.generator = nn.Sequential(
            nn.Linear(config['latent_dim'] * 2, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(2048, 64 * 8 * 39),
            nn.ReLU()
        )

        # Four stride-2 transposed convs upsample 8x39 -> 128x624
        # (n_mels x frames, matching config n_mels=128 / seq_len=624);
        # the final Tanh bounds the spectrogram to [-1, 1].
        self.conv_decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 128, (4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 256, (4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, (4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, (4, 4), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, padding=1),
            nn.Tanh()
        )

        # Not used at inference; present so training checkpoints that
        # include discriminator weights load without key mismatches.
        self.discriminator = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

        # BERT stays frozen: only projection/generator/decoder were trained.
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def encode_text(self, text_tokens):
        """Encode tokenized text into a (batch, latent_dim) latent vector.

        The frozen BERT forward pass runs under no_grad; the token at
        position 0 ([CLS]) is used as the sentence embedding before
        projection.
        """
        with torch.no_grad():
            text_features = self.text_encoder(**text_tokens).last_hidden_state
            text_features = text_features[:, 0, :]
        return self.text_projection(text_features)

    def generate_music(self, text_tokens, noise=None):
        """Generate a mel-spectrogram from tokenized text.

        Args:
            text_tokens: tokenizer output dict (must contain 'input_ids').
            noise: optional (batch, latent_dim) tensor; sampled from a
                standard normal on the text latent's device when omitted.

        Returns:
            Tensor of shape (batch, 1, 128, 624) in [-1, 1].
        """
        batch_size = text_tokens['input_ids'].shape[0]
        text_latent = self.encode_text(text_tokens)

        if noise is None:
            noise = torch.randn(batch_size, self.config['latent_dim']).to(text_latent.device)

        combined = torch.cat([text_latent, noise], dim=1)
        features = self.generator(combined)
        features = features.view(batch_size, 64, 8, 39)
        mel_spec = self.conv_decoder(features)

        return mel_spec

    def melspec_to_audio(self, melspec):
        """Convert a mel-spectrogram back to an audio waveform.

        Rescales the tanh output from [-1, 1] to [0, 1], maps it to a
        [-80, 0] dB range, converts dB to power, then inverts the mel
        filterbank with Griffin-Lim (n_iter=100 iterations).
        """
        melspec = (melspec + 1.0) / 2.0
        melspec_power = librosa.db_to_power(melspec * 80 - 80)

        audio = librosa.feature.inverse.mel_to_audio(
            melspec_power,
            sr=self.config['sample_rate'],
            hop_length=self.config['hop_length'],
            n_fft=self.config['n_fft'],
            n_iter=100
        )
        return audio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_ricco_model():
    """Download and load the Ricco model checkpoint from the HuggingFace Hub.

    Builds a TextToMusicGenerator with the fixed inference config, streams
    the checkpoint to a temporary file, and loads the state dict.

    Returns:
        tuple: (model, config) on success, (None, None) if the download or
        state-dict loading fails (the UI handler checks for None).
    """
    print("🎵 Carregando modelo Ricco...")

    config = {
        'sample_rate': 32000,
        'duration': 10,
        'n_mels': 128,
        'hop_length': 512,
        'n_fft': 2048,
        'latent_dim': 512,
        'text_embed_dim': 768,
        'seq_len': 624
    }

    model = TextToMusicGenerator(config)

    model_url = "https://huggingface.co/vcollos/ricco/resolve/main/text2music_final.pt"

    try:
        print("📥 Baixando modelo...")
        # Stream the download so the checkpoint is never fully buffered in
        # RAM, bound the request with a timeout so a stalled connection
        # cannot hang startup forever, and fail fast on HTTP errors instead
        # of writing an error page to disk and feeding it to torch.load.
        response = requests.get(model_url, stream=True, timeout=300)
        response.raise_for_status()

        with tempfile.NamedTemporaryFile() as tmp:
            for chunk in response.iter_content(chunk_size=1 << 20):
                tmp.write(chunk)
            tmp.flush()

            # NOTE(security): torch.load unpickles arbitrary objects; the
            # URL is a fixed, trusted repo here, but weights_only=True
            # would be safer if the checkpoint format allows it.
            checkpoint = torch.load(tmp.name, map_location='cpu')

        # Training checkpoints wrap weights under 'model_state_dict';
        # plain exports are the state dict itself.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()
        print("✅ Modelo Ricco carregado com sucesso!")
        return model, config

    except Exception as e:
        # Best-effort: the app still starts and the UI reports the failure.
        print(f"❌ Erro ao carregar modelo: {e}")
        return None, None
|
|
|
|
|
|
|
|
# One-time startup: load model weights and tokenizer at import time so every
# Gradio request reuses them. model is None if loading failed (the generation
# handler checks for that and reports it to the user).
model, config = load_ricco_model()
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# At import time on ZeroGPU this is typically CPU; the handler re-moves the
# model when a GPU is attached to the request.
if model:
    model = model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The 20 instrumental-track styles the Ricco model was trained on. Each name
# is expanded into the exact prompt phrasing the text encoder expects.
_ESTILOS_BASE = (
    "acao",
    "acoustic",
    "acustico",
    "drama",
    "eletronica",
    "emocional",
    "epico",
    "hip hop",
    "house",
    "jornalistico",
    "misterio",
    "motivacional",
    "oracao",
    "pop",
    "rock",
    "suspense",
    "tenso",
    "trap",
    "violao",
    "xote",
)

ESTILOS_RICCO = [
    f"trilha instrumental de estilo {estilo}" for estilo in _ESTILOS_BASE
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def gerar_trilha_ricco(prompt_selecionado, prompt_customizado, usar_customizado, temperatura):
    """Generate a track with the Ricco model on a ZeroGPU worker.

    Args:
        prompt_selecionado: style prompt selected from the dropdown.
        prompt_customizado: free-form prompt typed by the user.
        usar_customizado: when True (and the custom prompt is non-blank),
            the custom prompt is used instead of the dropdown selection.
        temperatura: scale factor applied to the latent noise; larger
            values produce more varied output.

    Returns:
        tuple: (wav_file_path, status_message) on success, or
        (None, error_message) on failure / when the model never loaded.
    """
    if not model:
        return None, "❌ Modelo não carregado. Tente recarregar a página."

    try:
        # Prefer the custom prompt only when the checkbox is on and the
        # textbox is not blank.
        if usar_customizado and prompt_customizado.strip():
            prompt = prompt_customizado.strip()
        else:
            prompt = prompt_selecionado

        print(f"🎵 Gerando: {prompt}")

        # Re-resolve the device inside the handler: under @spaces.GPU the
        # GPU is only attached while this decorated function is running.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_gpu = model.to(device)

        # Tokenize to the fixed length the model was trained with.
        text_tokens = tokenizer(
            prompt,
            padding='max_length',
            truncation=True,
            max_length=128,
            return_tensors='pt'
        ).to(device)

        with torch.no_grad():
            batch_size = text_tokens['input_ids'].shape[0]
            text_latent = model_gpu.encode_text(text_tokens)

            # Temperature scales the noise vector (creativity control).
            noise = torch.randn(batch_size, config['latent_dim']).to(device) * temperatura

            # Mirrors model.generate_music but with temperature-scaled noise.
            combined = torch.cat([text_latent, noise], dim=1)
            features = model_gpu.generator(combined)
            features = features.view(batch_size, 64, 8, 39)
            mel_spec = model_gpu.conv_decoder(features)

        # Take the single batch item and invert the mel-spectrogram to a
        # waveform on CPU (Griffin-Lim inside melspec_to_audio).
        mel_np = mel_spec[0, 0].cpu().numpy()
        audio = model_gpu.melspec_to_audio(mel_np)

        # Peak-normalize; the epsilon guards against an all-zero signal.
        audio = audio / (np.max(np.abs(audio)) + 1e-8)

        # delete=False: Gradio needs the file to outlive this handler so it
        # can stream it to the browser.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            sf.write(tmp.name, audio, config['sample_rate'])
            return tmp.name, f"🎵 Trilha gerada: {prompt} | GPU: {device}"

    except Exception as e:
        # Surface the error in the UI instead of crashing the Space.
        return None, f"❌ Erro na geração: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Custom CSS injected into the Gradio Blocks layout: page width cap, the
# gradient header banner, the controls panel, and the footer styling.
css = """
.container {
    max-width: 1200px;
    margin: auto;
}

.header {
    text-align: center;
    background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 2rem;
    border-radius: 15px;
    margin-bottom: 2rem;
}

.controls {
    background: #f8f9fa;
    padding: 1.5rem;
    border-radius: 10px;
    border: 1px solid #e9ecef;
}

.footer {
    text-align: center;
    margin-top: 2rem;
    color: #6c757d;
    font-size: 0.9rem;
}
"""
|
|
|
|
|
|
|
|
# ---- UI layout -------------------------------------------------------------
with gr.Blocks(css=css, title="🎵 Ricco - Gerador de Trilhas") as demo:

    # Header banner.
    gr.HTML("""
    <div class="header">
        <h1>🎵 Ricco - Gerador de Trilhas Sonoras para TV</h1>
        <p>Gere trilhas instrumentais profissionais em 20 estilos diferentes usando IA</p>
    </div>
    """)

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1, elem_classes="controls"):
            gr.Markdown("### 🎛️ Controles de Geração")

            style_dropdown = gr.Dropdown(
                choices=ESTILOS_RICCO,
                value=ESTILOS_RICCO[6],
                label="🎭 Estilo da Trilha",
                interactive=True,
            )

            custom_checkbox = gr.Checkbox(
                label="✏️ Usar prompt personalizado",
                value=False,
            )

            # Hidden until the checkbox above is ticked (see .change below).
            custom_prompt_box = gr.Textbox(
                label="🎨 Prompt Personalizado",
                placeholder="Ex: trilha instrumental de estilo jazz suave e relaxante",
                visible=False,
                lines=2,
            )

            temperature_slider = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="🌡️ Criatividade (Temperatura)",
                info="Maior = mais variação, Menor = mais consistente",
            )

            generate_button = gr.Button(
                "🎵 Gerar Trilha Sonora",
                variant="primary",
                size="lg",
            )

        # Right column: status and generated audio.
        with gr.Column(scale=1):
            gr.Markdown("### 🎶 Trilha Gerada")

            status_box = gr.Textbox(
                label="📊 Status",
                interactive=False,
                value="Pronto para gerar!",
            )

            audio_player = gr.Audio(
                label="🔊 Sua Trilha Sonora",
                type="filepath",
            )

            gr.Markdown("""
            **💡 Dica:** Cada geração dura ~10 segundos e é única!
            Experimente diferentes estilos ou ajuste a criatividade.
            """)

    gr.Markdown("### 🎯 Exemplos Rápidos")

    gr.Examples(
        examples=[
            ["trilha instrumental de estilo epico", "", False, 1.0],
            ["trilha instrumental de estilo suspense", "", False, 1.2],
            ["trilha instrumental de estilo pop", "", False, 0.8],
            ["trilha instrumental de estilo rock", "", False, 1.1],
            ["trilha instrumental de estilo emocional", "", False, 0.9],
        ],
        inputs=[style_dropdown, custom_prompt_box, custom_checkbox, temperature_slider],
        outputs=[audio_player, status_box],
        fn=gerar_trilha_ricco,
        cache_examples=False,
    )

    # Footer credits.
    gr.HTML("""
    <div class="footer">
        <p>🤖 Modelo Ricco treinado em 2.300 trilhas profissionais |
        ⚡ Geração em tempo real |
        🎬 Ideal para TV, vídeos e podcasts</p>
        <p>Criado por <a href="https://huggingface.co/vcollos" target="_blank">vcollos</a></p>
    </div>
    """)

    # Show the custom-prompt textbox only while the checkbox is ticked.
    custom_checkbox.change(
        lambda checked: gr.update(visible=checked),
        inputs=[custom_checkbox],
        outputs=[custom_prompt_box],
    )

    generate_button.click(
        gerar_trilha_ricco,
        inputs=[style_dropdown, custom_prompt_box, custom_checkbox, temperature_slider],
        outputs=[audio_player, status_box],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port; no public
    # share link is needed since Spaces already exposes the app.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)