James Edmunds commited on
Commit
2969f34
·
1 Parent(s): 4abfe60

feat: improve HuggingFace dataset loading with better error handling

Browse files
Files changed (1) hide show
  1. src/generator/generator.py +42 -37
src/generator/generator.py CHANGED
@@ -4,13 +4,15 @@ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
4
  from langchain_chroma import Chroma
5
  from langchain.chains import ConversationalRetrievalChain
6
  from langchain.prompts import PromptTemplate
 
 
 
7
  from config.settings import Settings
8
 
9
 
10
  class LyricGenerator:
11
  def __init__(self):
12
  """Initialize the generator with embeddings"""
13
- # Use Settings to determine the correct path based on environment
14
  self.embeddings_dir = Settings.get_embeddings_path()
15
  self.embeddings = OpenAIEmbeddings()
16
  self.vector_store = None
@@ -21,50 +23,53 @@ class LyricGenerator:
21
 
22
  def _load_embeddings(self) -> None:
23
  """Load existing embeddings based on environment"""
24
- print(f"Loading embeddings from: {self.embeddings_dir}")
25
- print(f"Deployment mode: {Settings.DEPLOYMENT_MODE}")
26
- print(f"Directory exists: {self.embeddings_dir.exists()}")
27
-
28
- if not self.embeddings_dir.exists():
29
- if Settings.is_huggingface():
30
- # List contents of root directory in HF
31
- try:
32
- print("Contents of root directory:")
33
- import os
34
- print(os.listdir("/"))
35
- except Exception as e:
36
- print(f"Error listing root: {str(e)}")
 
 
 
 
 
37
 
38
- raise RuntimeError(
39
- "Embeddings not found in HF storage. "
40
- "Please ensure they are uploaded to the dataset."
41
- )
42
- else:
 
 
 
 
 
 
43
  raise RuntimeError(
44
  "Embeddings not found locally. "
45
  "Please run process_lyrics.py first."
46
  )
47
 
48
- try:
49
- # List contents of embeddings directory
50
  try:
51
- print("Contents of embeddings directory:")
52
- print(os.listdir(str(self.embeddings_dir)))
 
 
 
 
53
  except Exception as e:
54
- print(f"Error listing embeddings dir: {str(e)}")
55
-
56
- # Load vector store using environment-aware settings
57
- self.vector_store = Chroma(
58
- persist_directory=str(self.embeddings_dir),
59
- embedding_function=self.embeddings,
60
- collection_name="lyrics"
61
- )
62
-
63
- # Setup QA chain
64
- self._setup_qa_chain()
65
-
66
- except Exception as e:
67
- raise RuntimeError(f"Failed to load embeddings: {str(e)}")
68
 
69
  def _setup_qa_chain(self) -> None:
70
  """Initialize the QA chain for generating lyrics"""
 
4
  from langchain_chroma import Chroma
5
  from langchain.chains import ConversationalRetrievalChain
6
  from langchain.prompts import PromptTemplate
7
+ from datasets import load_dataset
8
+ import tempfile
9
+ import shutil
10
  from config.settings import Settings
11
 
12
 
13
  class LyricGenerator:
14
  def __init__(self):
15
  """Initialize the generator with embeddings"""
 
16
  self.embeddings_dir = Settings.get_embeddings_path()
17
  self.embeddings = OpenAIEmbeddings()
18
  self.vector_store = None
 
23
 
24
  def _load_embeddings(self) -> None:
25
  """Load existing embeddings based on environment"""
26
+ if Settings.is_huggingface():
27
+ try:
28
+ print(f"Loading dataset from HuggingFace: {Settings.HF_DATASET}")
29
+ # Create a temporary directory to store the dataset
30
+ with tempfile.TemporaryDirectory() as temp_dir:
31
+ # Load the dataset
32
+ dataset = load_dataset(
33
+ path=Settings.HF_DATASET,
34
+ split="train",
35
+ token=Settings.HF_TOKEN # Add token for private dataset access
36
+ )
37
+ print("Dataset loaded, downloading files...")
38
+ # Download the chroma files
39
+ dataset.download(temp_dir)
40
+ chroma_path = Path(temp_dir) / "chroma"
41
+ print(f"Looking for Chroma files in: {chroma_path}")
42
+ if not chroma_path.exists():
43
+ raise RuntimeError(f"Chroma directory not found in dataset at {chroma_path}")
44
 
45
+ # Use the downloaded files
46
+ self.vector_store = Chroma(
47
+ persist_directory=str(chroma_path),
48
+ embedding_function=self.embeddings,
49
+ collection_name="lyrics"
50
+ )
51
+ print("Successfully loaded vector store from dataset")
52
+ except Exception as e:
53
+ raise RuntimeError(f"Failed to load HuggingFace dataset: {str(e)}")
54
+ else:
55
+ if not self.embeddings_dir.exists():
56
  raise RuntimeError(
57
  "Embeddings not found locally. "
58
  "Please run process_lyrics.py first."
59
  )
60
 
 
 
61
  try:
62
+ # Load vector store using environment-aware settings
63
+ self.vector_store = Chroma(
64
+ persist_directory=str(self.embeddings_dir),
65
+ embedding_function=self.embeddings,
66
+ collection_name="lyrics"
67
+ )
68
  except Exception as e:
69
+ raise RuntimeError(f"Failed to load local embeddings: {str(e)}")
70
+
71
+ # Setup QA chain
72
+ self._setup_qa_chain()
 
 
 
 
 
 
 
 
 
 
73
 
74
  def _setup_qa_chain(self) -> None:
75
  """Initialize the QA chain for generating lyrics"""