LyrGen2 / app.py
James-Edmunds's picture
Upload folder using huggingface_hub
2d21a4c verified
"""Main Streamlit application for lyric generation."""
import os
import streamlit as st
from pathlib import Path
from datetime import datetime
# Set environment before other imports: these variables need to be in place
# before the project modules imported below are loaded.
if os.getenv('DEPLOYMENT_MODE') != 'huggingface':
    # Local runs: point the dynamic loader at an SQLite install (this path
    # looks like a Homebrew prefix — presumably for a newer SQLite; confirm).
    # NOTE(review): dyld reads DYLD_LIBRARY_PATH when a process launches, so
    # setting it here likely only affects child processes — verify it has the
    # intended effect. It also overwrites any value already in the environment.
    os.environ['DYLD_LIBRARY_PATH'] = '/usr/local/opt/sqlite/lib'
else:
    # Hugging Face deployment: keep the HF cache under /data.
    os.environ['HF_HOME'] = '/data/.huggingface'
import re
from src.generator.generator import LyricGenerator
from config.settings import Settings
def format_lyrics(text: str) -> str:
    """Format generated lyrics for Markdown display.

    Bolds section markers such as ``[Verse 1]`` or ``[Chorus]`` and turns every
    newline into a Markdown hard line break, so consecutive lyric lines are not
    merged into one paragraph when rendered with ``st.markdown``.

    Args:
        text: Raw lyrics; surrounding whitespace is stripped and Windows
            (CRLF) line endings are accepted.

    Returns:
        The Markdown-formatted lyrics.
    """
    text = text.strip()
    # Bold section markers like [Verse 1], [Chorus], etc.
    text = re.sub(r'\[([^\]]+)\]', r'**[\1]**', text)
    # Markdown only forces a line break when a line ends with two spaces; a
    # single space renders as a soft wrap. Replace any trailing spaces/tabs and
    # an optional CR before each newline with exactly two spaces.
    text = re.sub(r'[ \t]*\r?\n', '  \n', text)
    return text
def _compute_db_stats(generator) -> dict:
    """Count chunks, unique songs and unique artists in the generator's vector store.

    Reads every metadata record from the underlying collection, so the cost
    grows with the size of the store.

    Args:
        generator: The initialized ``LyricGenerator``; its ``vector_store`` is
            expected to expose a ``_collection`` with a ``get`` method.

    Returns:
        A dict with integer ``"chunks"``, ``"songs"`` and ``"artists"`` counts.
    """
    # NOTE(review): ``_collection`` is a private attribute of the vector-store
    # wrapper and may change between library versions.
    collection = generator.vector_store._collection
    all_meta = collection.get(include=["metadatas"])["metadatas"]
    artists = set()
    songs = set()
    for meta in all_meta:
        # Chunks stored without metadata may come back as None.
        meta = meta or {}
        artist = meta.get("artist", "")
        artists.add(artist)
        songs.add((artist, meta.get("song_title", "")))
    return {
        "chunks": len(all_meta),
        "songs": len(songs),
        "artists": len(artists),
    }


def _initialize_session() -> None:
    """Create the generator and seed session state (once per session).

    On a generator failure the error is shown and the run is halted with
    ``st.stop()``. On success the script is rerun so the sidebar shows the
    DB stats.
    """
    print("===== Application Startup at", datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "=====\n")
    try:
        st.info("Initializing generator... this may take a moment.")
        print("\n=== Initializing Generator ===")
        generator = LyricGenerator()
    except Exception as e:
        st.error(f"Error initializing generator: {str(e)}")
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()
        st.stop()
        return  # defensive: st.stop() is not expected to return

    st.session_state.generator = generator
    st.session_state.chat_history = []
    st.session_state.current_lyrics = None
    st.session_state.initialized = True

    # The sidebar stats are cosmetic: a failure here must not be reported as a
    # generator failure or block lyric generation.
    try:
        st.session_state.db_stats = _compute_db_stats(generator)
    except Exception as e:
        print(f"Could not compute DB stats: {e}")

    print("Generator initialized successfully")
    # Called outside any try/except so the rerun signal can never be caught
    # by an error handler.
    st.rerun()


def _render_sources(response: dict) -> None:
    """Render the "View Sources and Context" expander for one generation.

    Args:
        response: The dict returned by ``generate_lyrics``; the optional
            ``"context_details"`` and ``"source_documents"`` entries are shown.
    """
    with st.expander("View Sources and Context"):
        # Show top retrieved contexts with snippets
        st.write("### Retrieved Contexts")
        for detail in response.get("context_details", []):
            st.write(f"\n**{detail['artist']} - {detail['song']}**")
            st.text(detail['content'])
            st.write("---")
        # Show all unique source songs from the chain, in first-seen order
        st.write("### All Sources Used")
        seen_sources = set()
        for doc in response.get("source_documents", []):
            source_key = (
                doc.metadata.get('artist', 'Unknown'),
                doc.metadata.get('song_title', 'Unknown'),
            )
            if source_key not in seen_sources:
                seen_sources.add(source_key)
                st.write(f"- {source_key[0]} - {source_key[1]}")
        unique_artists = {artist for artist, _ in seen_sources}
        st.write(f"\n*{len(seen_sources)} unique songs from {len(unique_artists)} artists*")


def main():
    """Run the Streamlit lyric-generation app (executed on every interaction)."""
    st.set_page_config(
        page_title="SongLift LyrGen2",
        page_icon="🎵",
        layout="wide",
    )
    st.title("SongLift LyrGen2")

    if st.sidebar.button("New Song"):
        st.session_state.chat_history = []
        st.session_state.current_lyrics = None
        st.rerun()

    # Show DB stats if available
    if "db_stats" in st.session_state:
        stats = st.session_state.db_stats
        st.sidebar.markdown("---")
        st.sidebar.metric("Artists", f"{stats['artists']:,}")
        st.sidebar.metric("Songs", f"{stats['songs']:,}")
        st.sidebar.metric("Chunks", f"{stats['chunks']:,}")

    # Check the OpenAI API key before the slow generator initialization so a
    # missing key is reported directly instead of after (or as) an init error.
    if not Settings.OPENAI_API_KEY:
        st.error("OpenAI API key not found. Please set OPENAI_API_KEY.")
        st.stop()
        return

    # Only run startup once per session
    if 'initialized' not in st.session_state:
        _initialize_session()
        return  # _initialize_session always ends with a rerun or a stop

    # Display chat history
    for user_msg, assistant_msg in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(user_msg)
        with st.chat_message("assistant"):
            st.markdown(format_lyrics(assistant_msg))

    # Chat interface
    user_input = st.chat_input("Enter your prompt (ask for new lyrics or modify existing ones)...")
    if not user_input:
        return

    with st.chat_message("user"):
        st.write(user_input)
    with st.chat_message("assistant"):
        try:
            with st.spinner("Generating lyrics..."):
                response = st.session_state.generator.generate_lyrics(
                    user_input,
                    st.session_state.chat_history,
                )
            lyrics = response['answer']
            st.markdown(format_lyrics(lyrics))
            st.session_state.current_lyrics = lyrics
            # Record the turn before rendering sources so a rendering error
            # cannot drop lyrics the user has already seen from the history.
            st.session_state.chat_history.append((user_input, lyrics))
            _render_sources(response)
        except Exception as e:
            st.error(f"Error generating lyrics: {str(e)}")


if __name__ == "__main__":
    main()