File size: 5,676 Bytes
322a04b
ac06d04
4994b71
2fb2290
4f8d0ad
4994b71
d305e69
 
 
 
ac06d04
 
4c377bd
 
d305e69
 
 
ac06d04
4c377bd
2d21a4c
 
4c377bd
 
2d21a4c
 
 
4c377bd
 
ac06d04
 
 
40c5feb
ac06d04
 
 
d305e69
40c5feb
a998f2d
 
 
 
d305e69
e366acd
 
 
 
 
 
 
 
ac8d6e6
 
 
5ccf3cf
ac8d6e6
 
d305e69
 
 
 
 
 
 
e366acd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d305e69
 
ac8d6e6
 
 
5ccf3cf
 
 
d305e69
 
 
 
 
 
ac06d04
 
 
 
 
 
4c377bd
d305e69
ac06d04
d305e69
 
ac06d04
 
 
d305e69
ac06d04
 
d305e69
ac06d04
 
 
 
d305e69
ac06d04
 
4c377bd
ac06d04
d305e69
ac06d04
 
a998f2d
 
ac06d04
 
a998f2d
ac06d04
 
 
d305e69
a998f2d
 
ac06d04
a998f2d
 
ac06d04
a998f2d
 
ac06d04
 
 
a998f2d
 
d305e69
ac06d04
 
d305e69
ac06d04
 
 
 
d305e69
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Main Streamlit application for lyric generation."""
import os
import streamlit as st
from pathlib import Path
from datetime import datetime

# Configure process environment variables up front, before the project
# imports further down, which may read them at import time.
if os.getenv('DEPLOYMENT_MODE') != 'huggingface':
    # Local (non-Hugging Face) run: point at a Homebrew-style SQLite build.
    # NOTE(review): dyld normally reads DYLD_LIBRARY_PATH when a process is
    # launched, so setting it here likely only affects child processes —
    # confirm this has the intended effect. It also overwrites any existing value.
    os.environ['DYLD_LIBRARY_PATH'] = '/usr/local/opt/sqlite/lib'
else:
    # Hugging Face deployment: keep the Hugging Face cache under /data.
    os.environ['HF_HOME'] = '/data/.huggingface'

import re

from src.generator.generator import LyricGenerator
from config.settings import Settings


# Section markers such as "[Verse 1]" or "[Chorus]", confined to one line.
# The lookbehind/lookahead skip markers that are already bolded
# ("**[Chorus]**"), so they are not wrapped twice. The "(" lookahead skips
# Markdown links ("[text](url)") so they stay links.
_SECTION_MARKER_RE = re.compile(r'(?<!\*\*)\[([^\]\n]+)\](?!\*\*|\()')
# Trailing spaces/tabs followed by a line ending (LF or CRLF).
_LINE_END_RE = re.compile(r'[ \t]*\r?\n')


def format_lyrics(text: str) -> str:
    """Format lyrics for ``st.markdown``: bold section markers, force line breaks.

    Section markers like ``[Verse 1]`` are wrapped in ``**...**``. Every line
    ending becomes a Markdown hard line break (exactly two trailing spaces).
    Without that, Markdown would merge consecutive lyric lines into one
    paragraph.

    Args:
        text: Raw lyrics text, possibly with CRLF line endings.

    Returns:
        The formatted Markdown string with surrounding whitespace stripped.
    """
    text = text.strip()
    text = _SECTION_MARKER_RE.sub(r'**[\1]**', text)
    # Normalize trailing whitespace and CRLF so each line ends in exactly
    # "  \n", the Markdown hard-break marker.
    return _LINE_END_RE.sub('  \n', text)


def _compute_db_stats(generator: "LyricGenerator") -> dict:
    """Count chunks, distinct songs and distinct artists in the vector store.

    Reads every chunk's metadata via ``generator.vector_store._collection``
    (a private attribute of the vector-store wrapper). Songs are keyed on the
    ``(artist, song_title)`` pair, so same-titled songs by different artists
    are counted separately.
    """
    collection = generator.vector_store._collection
    all_meta = collection.get(include=["metadatas"])["metadatas"] or []
    # Chunks stored without metadata may come back as None; treat them as empty.
    metas = [meta or {} for meta in all_meta]
    return {
        "chunks": len(metas),
        "songs": len({(m.get("artist", ""), m.get("song_title", "")) for m in metas}),
        "artists": len({m.get("artist", "") for m in metas}),
    }


def _initialize_session() -> None:
    """Build the generator and per-session state, then rerun the script.

    Stops the app if the generator cannot be constructed. A failure while
    computing the optional sidebar stats is only logged.
    """
    import traceback

    print("===== Application Startup at", datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "=====\n")
    try:
        st.info("Initializing generator... this may take a moment.")
        print("\n=== Initializing Generator ===")
        generator = LyricGenerator()
    except Exception as e:
        st.error(f"Error initializing generator: {e}")
        print(f"Error: {e}")
        traceback.print_exc()
        st.stop()  # halts this script run; nothing below executes

    st.session_state.generator = generator
    st.session_state.chat_history = []
    st.session_state.current_lyrics = None

    # Sidebar stats are informational only. The generator is already usable,
    # so a failure here is logged instead of stopping the app.
    try:
        st.session_state.db_stats = _compute_db_stats(generator)
    except Exception as e:
        print(f"Warning: could not compute DB stats: {e}")
        traceback.print_exc()

    st.session_state.initialized = True
    print("Generator initialized successfully")
    # Kept outside any broad `except Exception` so Streamlit's control-flow
    # exception is never intercepted.
    st.rerun()


def _render_sidebar() -> None:
    """Render the "New Song" reset button and, when available, DB stats."""
    if st.sidebar.button("New Song"):
        st.session_state.chat_history = []
        st.session_state.current_lyrics = None
        st.rerun()

    # db_stats is only present once initialization has computed it.
    if "db_stats" in st.session_state:
        stats = st.session_state.db_stats
        st.sidebar.markdown("---")
        st.sidebar.metric("Artists", f"{stats['artists']:,}")
        st.sidebar.metric("Songs", f"{stats['songs']:,}")
        st.sidebar.metric("Chunks", f"{stats['chunks']:,}")


def _render_sources(response: dict) -> None:
    """Show retrieved contexts and the de-duplicated list of source songs.

    Missing keys fall back to empty lists / "Unknown", so an incomplete
    response still renders.
    """
    with st.expander("View Sources and Context"):
        # Show top retrieved contexts with snippets
        st.write("### Retrieved Contexts")
        for detail in response.get("context_details") or []:
            artist = detail.get('artist', 'Unknown')
            song = detail.get('song', 'Unknown')
            st.write(f"\n**{artist} - {song}**")
            st.text(detail.get('content', ''))
            st.write("---")

        # Show all unique source songs, in first-seen order.
        st.write("### All Sources Used")
        unique_sources = dict.fromkeys(
            (doc.metadata.get('artist', 'Unknown'), doc.metadata.get('song_title', 'Unknown'))
            for doc in response.get("source_documents") or []
        )
        for artist, song in unique_sources:
            st.write(f"- {artist} - {song}")
        artist_count = len({artist for artist, _ in unique_sources})
        st.write(f"\n*{len(unique_sources)} unique songs from {artist_count} artists*")


def main():
    """Main application function: render the chat UI and generate lyrics."""
    st.set_page_config(
        page_title="SongLift LyrGen2",
        page_icon="🎵",
        layout="wide"
    )

    st.title("SongLift LyrGen2")
    _render_sidebar()

    # Check the key before building the generator, so a missing key shows
    # this message instead of whatever error generator construction raises.
    if not Settings.OPENAI_API_KEY:
        st.error("OpenAI API key not found. Please set OPENAI_API_KEY.")
        st.stop()

    # Only run startup once per session
    if 'initialized' not in st.session_state:
        _initialize_session()

    # Replay the conversation so far
    for user_msg, assistant_msg in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(user_msg)
        with st.chat_message("assistant"):
            st.markdown(format_lyrics(assistant_msg))

    user_input = st.chat_input("Enter your prompt (ask for new lyrics or modify existing ones)...")
    if not user_input:
        return

    with st.chat_message("user"):
        st.write(user_input)

    with st.chat_message("assistant"):
        try:
            with st.spinner("Generating lyrics..."):
                response = st.session_state.generator.generate_lyrics(
                    user_input,
                    st.session_state.chat_history,
                )
            lyrics = response['answer']
            st.markdown(format_lyrics(lyrics))
            st.session_state.current_lyrics = lyrics
            # Record the turn before the optional sources UI. Otherwise a
            # failure there would drop lyrics the user has already seen from
            # the history sent with the next prompt.
            st.session_state.chat_history.append((user_input, lyrics))
        except Exception as e:
            import traceback

            st.error(f"Error generating lyrics: {e}")
            traceback.print_exc()
        else:
            try:
                _render_sources(response)
            except Exception as e:
                st.warning(f"Could not display sources: {e}")


if __name__ == "__main__":
    main()