ricco / app.py
vcollos's picture
Update app.py
be89d29 verified
# =============================================================================
# RICCO - GERADOR DE TRILHAS SONORAS
# App Gradio para Hugging Face Space com ZeroGPU
# =============================================================================
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
import librosa
import soundfile as sf
from transformers import AutoTokenizer
import tempfile
import requests
import spaces # Para ZeroGPU
# =============================================================================
# ARQUITETURA DO MODELO RICCO
# =============================================================================
class TextToMusicGenerator(nn.Module):
"""Arquitetura do modelo Ricco (com discriminator para compatibilidade)"""
def __init__(self, config):
super().__init__()
self.config = config
# Text encoder (BERT)
from transformers import AutoModel
self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')
# Projeção do texto
self.text_projection = nn.Sequential(
nn.Linear(config['text_embed_dim'], config['latent_dim']),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(config['latent_dim'], config['latent_dim'])
)
# Generator
self.generator = nn.Sequential(
nn.Linear(config['latent_dim'] * 2, 1024),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(1024, 2048),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(2048, 64 * 8 * 39),
nn.ReLU()
)
# Decoder convolucional
self.conv_decoder = nn.Sequential(
nn.ConvTranspose2d(64, 128, (4, 4), stride=(2, 2), padding=(1, 1)),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.ConvTranspose2d(128, 256, (4, 4), stride=(2, 2), padding=(1, 1)),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, (4, 4), stride=(2, 2), padding=(1, 1)),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, (4, 4), stride=(2, 2), padding=(1, 1)),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 32, 3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 1, 3, padding=1),
nn.Tanh()
)
# Discriminator (para compatibilidade com o modelo treinado)
self.discriminator = nn.Sequential(
nn.Conv2d(1, 32, 4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(32, 64, 4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(256, 1),
nn.Sigmoid()
)
# Congelar BERT
for param in self.text_encoder.parameters():
param.requires_grad = False
def encode_text(self, text_tokens):
with torch.no_grad():
text_features = self.text_encoder(**text_tokens).last_hidden_state
text_features = text_features[:, 0, :]
return self.text_projection(text_features)
def generate_music(self, text_tokens, noise=None):
batch_size = text_tokens['input_ids'].shape[0]
text_latent = self.encode_text(text_tokens)
if noise is None:
noise = torch.randn(batch_size, self.config['latent_dim']).to(text_latent.device)
combined = torch.cat([text_latent, noise], dim=1)
features = self.generator(combined)
features = features.view(batch_size, 64, 8, 39)
mel_spec = self.conv_decoder(features)
return mel_spec
def melspec_to_audio(self, melspec):
"""Converter mel-spectrogram para áudio"""
melspec = (melspec + 1.0) / 2.0
melspec_power = librosa.db_to_power(melspec * 80 - 80)
audio = librosa.feature.inverse.mel_to_audio(
melspec_power,
sr=self.config['sample_rate'],
hop_length=self.config['hop_length'],
n_fft=self.config['n_fft'],
n_iter=100
)
return audio
# =============================================================================
# CARREGAR MODELO RICCO
# =============================================================================
def load_ricco_model():
"""Carregar modelo Ricco do HuggingFace"""
print("🎵 Carregando modelo Ricco...")
# Configuração do modelo
config = {
'sample_rate': 32000,
'duration': 10,
'n_mels': 128,
'hop_length': 512,
'n_fft': 2048,
'latent_dim': 512,
'text_embed_dim': 768,
'seq_len': 624
}
# Criar modelo
model = TextToMusicGenerator(config)
# Baixar pesos do HuggingFace
model_url = "https://huggingface.co/vcollos/ricco/resolve/main/text2music_final.pt"
try:
# Download temporário
print("📥 Baixando modelo...")
response = requests.get(model_url)
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(response.content)
tmp.flush()
# Carregar pesos
checkpoint = torch.load(tmp.name, map_location='cpu')
if 'model_state_dict' in checkpoint:
model.load_state_dict(checkpoint['model_state_dict'])
else:
model.load_state_dict(checkpoint)
model.eval()
print("✅ Modelo Ricco carregado com sucesso!")
return model, config
except Exception as e:
print(f"❌ Erro ao carregar modelo: {e}")
return None, None
# Carregar modelo e tokenizer
model, config = load_ricco_model()
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if model:
model = model.to(device)
# =============================================================================
# ESTILOS DISPONÍVEIS
# =============================================================================
ESTILOS_RICCO = [
"trilha instrumental de estilo acao",
"trilha instrumental de estilo acoustic",
"trilha instrumental de estilo acustico",
"trilha instrumental de estilo drama",
"trilha instrumental de estilo eletronica",
"trilha instrumental de estilo emocional",
"trilha instrumental de estilo epico",
"trilha instrumental de estilo hip hop",
"trilha instrumental de estilo house",
"trilha instrumental de estilo jornalistico",
"trilha instrumental de estilo misterio",
"trilha instrumental de estilo motivacional",
"trilha instrumental de estilo oracao",
"trilha instrumental de estilo pop",
"trilha instrumental de estilo rock",
"trilha instrumental de estilo suspense",
"trilha instrumental de estilo tenso",
"trilha instrumental de estilo trap",
"trilha instrumental de estilo violao",
"trilha instrumental de estilo xote"
]
# =============================================================================
# FUNÇÃO DE GERAÇÃO COM ZEROGPU
# =============================================================================
@spaces.GPU
def gerar_trilha_ricco(prompt_selecionado, prompt_customizado, usar_customizado, temperatura):
"""Gerar trilha com o modelo Ricco usando ZeroGPU"""
if not model:
return None, "❌ Modelo não carregado. Tente recarregar a página."
try:
# Escolher prompt
if usar_customizado and prompt_customizado.strip():
prompt = prompt_customizado.strip()
else:
prompt = prompt_selecionado
print(f"🎵 Gerando: {prompt}")
# Mover modelo para GPU se disponível
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_gpu = model.to(device)
# Tokenizar
text_tokens = tokenizer(
prompt,
padding='max_length',
truncation=True,
max_length=128,
return_tensors='pt'
).to(device)
# Gerar com variação controlada
with torch.no_grad():
batch_size = text_tokens['input_ids'].shape[0]
text_latent = model_gpu.encode_text(text_tokens)
# Noise com temperatura
noise = torch.randn(batch_size, config['latent_dim']).to(device) * temperatura
combined = torch.cat([text_latent, noise], dim=1)
features = model_gpu.generator(combined)
features = features.view(batch_size, 64, 8, 39)
mel_spec = model_gpu.conv_decoder(features)
# Converter para áudio
mel_np = mel_spec[0, 0].cpu().numpy()
audio = model_gpu.melspec_to_audio(mel_np)
# Normalizar
audio = audio / (np.max(np.abs(audio)) + 1e-8)
# Salvar temporariamente
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
sf.write(tmp.name, audio, config['sample_rate'])
return tmp.name, f"🎵 Trilha gerada: {prompt} | GPU: {device}"
except Exception as e:
return None, f"❌ Erro na geração: {str(e)}"
# =============================================================================
# INTERFACE GRADIO
# =============================================================================
# CSS customizado
css = """
.container {
max-width: 1200px;
margin: auto;
}
.header {
text-align: center;
background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 2rem;
border-radius: 15px;
margin-bottom: 2rem;
}
.controls {
background: #f8f9fa;
padding: 1.5rem;
border-radius: 10px;
border: 1px solid #e9ecef;
}
.footer {
text-align: center;
margin-top: 2rem;
color: #6c757d;
font-size: 0.9rem;
}
"""
# Interface principal
with gr.Blocks(css=css, title="🎵 Ricco - Gerador de Trilhas") as demo:
# Header
gr.HTML("""
<div class="header">
<h1>🎵 Ricco - Gerador de Trilhas Sonoras para TV</h1>
<p>Gere trilhas instrumentais profissionais em 20 estilos diferentes usando IA</p>
</div>
""")
with gr.Row():
# Controles
with gr.Column(scale=1, elem_classes="controls"):
gr.Markdown("### 🎛️ Controles de Geração")
prompt_dropdown = gr.Dropdown(
choices=ESTILOS_RICCO,
value=ESTILOS_RICCO[6], # Épico como padrão
label="🎭 Estilo da Trilha",
interactive=True
)
usar_custom = gr.Checkbox(
label="✏️ Usar prompt personalizado",
value=False
)
prompt_custom = gr.Textbox(
label="🎨 Prompt Personalizado",
placeholder="Ex: trilha instrumental de estilo jazz suave e relaxante",
visible=False,
lines=2
)
temperatura = gr.Slider(
minimum=0.5,
maximum=2.0,
value=1.0,
step=0.1,
label="🌡️ Criatividade (Temperatura)",
info="Maior = mais variação, Menor = mais consistente"
)
gerar_btn = gr.Button(
"🎵 Gerar Trilha Sonora",
variant="primary",
size="lg"
)
# Resultados
with gr.Column(scale=1):
gr.Markdown("### 🎶 Trilha Gerada")
status = gr.Textbox(
label="📊 Status",
interactive=False,
value="Pronto para gerar!"
)
audio_output = gr.Audio(
label="🔊 Sua Trilha Sonora",
type="filepath"
)
gr.Markdown("""
**💡 Dica:** Cada geração dura ~10 segundos e é única!
Experimente diferentes estilos ou ajuste a criatividade.
""")
# Exemplos predefinidos
gr.Markdown("### 🎯 Exemplos Rápidos")
examples = gr.Examples(
examples=[
["trilha instrumental de estilo epico", "", False, 1.0],
["trilha instrumental de estilo suspense", "", False, 1.2],
["trilha instrumental de estilo pop", "", False, 0.8],
["trilha instrumental de estilo rock", "", False, 1.1],
["trilha instrumental de estilo emocional", "", False, 0.9]
],
inputs=[prompt_dropdown, prompt_custom, usar_custom, temperatura],
outputs=[audio_output, status],
fn=gerar_trilha_ricco,
cache_examples=False
)
# Footer
gr.HTML("""
<div class="footer">
<p>🤖 Modelo Ricco treinado em 2.300 trilhas profissionais |
⚡ Geração em tempo real |
🎬 Ideal para TV, vídeos e podcasts</p>
<p>Criado por <a href="https://huggingface.co/vcollos" target="_blank">vcollos</a></p>
</div>
""")
# Eventos
usar_custom.change(
lambda x: gr.update(visible=x),
inputs=[usar_custom],
outputs=[prompt_custom]
)
gerar_btn.click(
gerar_trilha_ricco,
inputs=[prompt_dropdown, prompt_custom, usar_custom, temperatura],
outputs=[audio_output, status]
)
# =============================================================================
# LAUNCH
# =============================================================================
if __name__ == "__main__":
demo.launch(
share=False,
server_name="0.0.0.0",
server_port=7860
)