James Edmunds committed on
Commit ·
e2b66bb
1
Parent(s): 1af442e
revert: remove unnecessary connection handling changes
Browse files- config/settings.py +0 -5
- requirements.txt +0 -1
- src/generator/generator.py +10 -42
config/settings.py
CHANGED
|
@@ -13,11 +13,6 @@ class Settings:
|
|
| 13 |
|
| 14 |
# API Keys
|
| 15 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 16 |
-
if not OPENAI_API_KEY:
|
| 17 |
-
raise ValueError(
|
| 18 |
-
"OpenAI API key not found. Please set OPENAI_API_KEY "
|
| 19 |
-
"environment variable."
|
| 20 |
-
)
|
| 21 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 22 |
|
| 23 |
# Directory Paths
|
|
|
|
| 13 |
|
| 14 |
# API Keys
|
| 15 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 17 |
|
| 18 |
# Directory Paths
|
requirements.txt
CHANGED
|
@@ -6,7 +6,6 @@ chromadb>=0.4.22
|
|
| 6 |
streamlit==1.41.0
|
| 7 |
python-dotenv==1.0.1
|
| 8 |
huggingface-hub>=0.19.4
|
| 9 |
-
datasets>=2.16.0
|
| 10 |
pytest>=7.4.0
|
| 11 |
black>=23.0.0
|
| 12 |
flake8>=6.0.0
|
|
|
|
| 6 |
streamlit==1.41.0
|
| 7 |
python-dotenv==1.0.1
|
| 8 |
huggingface-hub>=0.19.4
|
|
|
|
| 9 |
pytest>=7.4.0
|
| 10 |
black>=23.0.0
|
| 11 |
flake8>=6.0.0
|
src/generator/generator.py
CHANGED
|
@@ -12,10 +12,7 @@ class LyricGenerator:
|
|
| 12 |
def __init__(self):
|
| 13 |
"""Initialize the generator with embeddings"""
|
| 14 |
self.embeddings_dir = Settings.get_embeddings_path()
|
| 15 |
-
self.embeddings = OpenAIEmbeddings(
|
| 16 |
-
request_timeout=60, # Increase timeout for embeddings
|
| 17 |
-
max_retries=3 # Add retries for robustness
|
| 18 |
-
)
|
| 19 |
self.vector_store = None
|
| 20 |
self.qa_chain = None
|
| 21 |
|
|
@@ -163,16 +160,14 @@ class LyricGenerator:
|
|
| 163 |
template=system_template
|
| 164 |
)
|
| 165 |
|
| 166 |
-
# Initialize language model
|
| 167 |
llm = ChatOpenAI(
|
| 168 |
temperature=0.9,
|
| 169 |
model_name="gpt-4",
|
| 170 |
max_tokens=1000,
|
| 171 |
top_p=0.95,
|
| 172 |
presence_penalty=0.0,
|
| 173 |
-
frequency_penalty=0.1
|
| 174 |
-
http_client=None, # Let OpenAI handle proxy settings
|
| 175 |
-
request_timeout=60 # Increase timeout
|
| 176 |
)
|
| 177 |
|
| 178 |
# Create QA chain
|
|
@@ -223,41 +218,15 @@ class LyricGenerator:
|
|
| 223 |
'artist': doc.metadata['artist'],
|
| 224 |
'song': doc.metadata['song_title'],
|
| 225 |
'similarity': similarity,
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
try:
|
| 230 |
-
# Test OpenAI connection first
|
| 231 |
-
print("Testing OpenAI connection...")
|
| 232 |
-
test_response = self.embeddings.embed_query("test")
|
| 233 |
-
print("OpenAI connection successful")
|
| 234 |
-
|
| 235 |
-
# Generate response using invoke
|
| 236 |
-
print("Generating lyrics with QA chain...")
|
| 237 |
-
response = self.qa_chain.invoke({
|
| 238 |
-
"question": prompt,
|
| 239 |
-
"chat_history": chat_history
|
| 240 |
})
|
| 241 |
-
print("Successfully generated lyrics")
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
"OpenAI API authentication failed. Please check your API key."
|
| 249 |
-
)
|
| 250 |
-
elif "rate" in error_msg.lower():
|
| 251 |
-
raise RuntimeError(
|
| 252 |
-
"OpenAI API rate limit exceeded. Please try again in a moment."
|
| 253 |
-
)
|
| 254 |
-
elif "connect" in error_msg.lower():
|
| 255 |
-
raise RuntimeError(
|
| 256 |
-
"Connection to OpenAI failed. Please check your internet "
|
| 257 |
-
"connection and try again."
|
| 258 |
-
)
|
| 259 |
-
else:
|
| 260 |
-
raise RuntimeError(f"OpenAI API error: {error_msg}")
|
| 261 |
|
| 262 |
# Add detailed context to response
|
| 263 |
response["source_documents_with_scores"] = docs_and_scores
|
|
@@ -266,5 +235,4 @@ class LyricGenerator:
|
|
| 266 |
return response
|
| 267 |
|
| 268 |
except Exception as e:
|
| 269 |
-
print(f"Error in generate_lyrics: {str(e)}")
|
| 270 |
raise RuntimeError(f"Failed to generate lyrics: {str(e)}")
|
|
|
|
| 12 |
def __init__(self):
|
| 13 |
"""Initialize the generator with embeddings"""
|
| 14 |
self.embeddings_dir = Settings.get_embeddings_path()
|
| 15 |
+
self.embeddings = OpenAIEmbeddings()
|
|
|
|
|
|
|
|
|
|
| 16 |
self.vector_store = None
|
| 17 |
self.qa_chain = None
|
| 18 |
|
|
|
|
| 160 |
template=system_template
|
| 161 |
)
|
| 162 |
|
| 163 |
+
# Initialize language model
|
| 164 |
llm = ChatOpenAI(
|
| 165 |
temperature=0.9,
|
| 166 |
model_name="gpt-4",
|
| 167 |
max_tokens=1000,
|
| 168 |
top_p=0.95,
|
| 169 |
presence_penalty=0.0,
|
| 170 |
+
frequency_penalty=0.1
|
|
|
|
|
|
|
| 171 |
)
|
| 172 |
|
| 173 |
# Create QA chain
|
|
|
|
| 218 |
'artist': doc.metadata['artist'],
|
| 219 |
'song': doc.metadata['song_title'],
|
| 220 |
'similarity': similarity,
|
| 221 |
+
# First 200 chars
|
| 222 |
+
'content': doc.page_content[:200] + "..."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
})
|
|
|
|
| 224 |
|
| 225 |
+
# Generate response using invoke
|
| 226 |
+
response = self.qa_chain.invoke({
|
| 227 |
+
"question": prompt,
|
| 228 |
+
"chat_history": chat_history
|
| 229 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
# Add detailed context to response
|
| 232 |
response["source_documents_with_scores"] = docs_and_scores
|
|
|
|
| 235 |
return response
|
| 236 |
|
| 237 |
except Exception as e:
|
|
|
|
| 238 |
raise RuntimeError(f"Failed to generate lyrics: {str(e)}")
|