broadfield-dev committed on
Commit
2e63d3b
·
verified ·
1 Parent(s): 5ef463a

Update build_rag.py

Browse files
Files changed (1) hide show
  1. build_rag.py +18 -18
build_rag.py CHANGED
@@ -1,8 +1,10 @@
 
 
1
  import json
2
  import os
3
  import pandas as pd
4
  import torch
5
- import torch.nn.functional as F # Import the functional module
6
  from transformers import AutoTokenizer, AutoModel
7
  import chromadb
8
  import sys
@@ -13,14 +15,14 @@ import traceback
13
  # --- Configuration ---
14
  CHROMA_PATH = "chroma_db"
15
  COLLECTION_NAME = "bible_verses"
16
- # *** CHANGE 1: UPDATE THE MODEL NAME ***
17
- MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
18
- # *** CHANGE 2: UPDATE THE DATASET REPO NAME TO AVOID CONFUSION ***
19
- DATASET_REPO = "broadfield-dev/bible-chromadb-mpnet"
20
  STATUS_FILE = "build_status.log"
21
  JSON_DIRECTORY = 'bible_json'
22
  CHUNK_SIZE = 3
23
- EMBEDDING_BATCH_SIZE = 16 # Adjust based on available VRAM
24
  # (BOOK_ID_TO_NAME dictionary remains the same)
25
  BOOK_ID_TO_NAME = {
26
  1: "Genesis", 2: "Exodus", 3: "Leviticus", 4: "Numbers", 5: "Deuteronomy",
@@ -44,13 +46,13 @@ def update_status(message):
44
  with open(STATUS_FILE, "w") as f:
45
  f.write(message)
46
 
47
- # Mean Pooling Function - Take attention mask into account for correct averaging
48
  def mean_pooling(model_output, attention_mask):
49
- token_embeddings = model_output[0] #First element of model_output contains all token embeddings
50
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
51
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
52
 
53
- def process_bible_json_files(directory_path: str, chunk_size: int) -> pd.DataFrame:
54
  # (This function is unchanged)
55
  all_verses = []
56
  if not os.path.exists(directory_path) or not os.listdir(directory_path):
@@ -92,36 +94,35 @@ def main():
92
 
93
  collection = client.create_collection(
94
  name=COLLECTION_NAME,
95
- metadata={"hnsw:space": "cosine"} # Use cosine distance
96
  )
97
 
98
  update_status(f"IN_PROGRESS: Step 3/5 - Loading embedding model '{MODEL_NAME}'...")
99
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
100
  model = AutoModel.from_pretrained(MODEL_NAME, device_map="auto")
101
 
102
- update_status("IN_PROGRESS: Step 4/5 - Generating and NORMALIZING embeddings...")
103
  for i in tqdm(range(0, len(bible_chunks_df), EMBEDDING_BATCH_SIZE), desc="Embedding Chunks"):
104
  batch_df = bible_chunks_df.iloc[i:i+EMBEDDING_BATCH_SIZE]
105
  texts = batch_df['text'].tolist()
106
 
107
- # *** CHANGE 3: USE THE CORRECT POOLING STRATEGY FOR SBERT MODELS ***
108
  encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(model.device)
109
  with torch.no_grad():
110
  model_output = model(**encoded_input)
111
 
112
- # Perform pooling and normalization
113
- sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
114
- normalized_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
 
115
 
116
  collection.add(
117
  ids=[str(j) for j in range(i, i + len(batch_df))],
118
- embeddings=normalized_embeddings.cpu().tolist(), # Convert to list
119
  documents=texts,
120
  metadatas=batch_df[['reference', 'version']].to_dict('records')
121
  )
122
 
123
  update_status(f"IN_PROGRESS: Step 5/5 - Pushing database to Hugging Face Hub '{DATASET_REPO}'...")
124
- # (This part is unchanged)
125
  create_repo(repo_id=DATASET_REPO, repo_type="dataset", exist_ok=True)
126
  api = HfApi()
127
  api.upload_folder(
@@ -136,7 +137,6 @@ if __name__ == "__main__":
136
  try:
137
  main()
138
  except Exception as e:
139
- # (Error handling is unchanged)
140
  error_message = traceback.format_exc()
141
  if "401" in str(e) or "Unauthorized" in str(e):
142
  update_status("FAILED: Hugging Face authentication error. Ensure your HF_TOKEN secret has WRITE permissions.")
 
1
+ # build_rag.py (Updated for a model with pre-normalized embeddings)
2
+
3
  import json
4
  import os
5
  import pandas as pd
6
  import torch
7
+ import torch.nn.functional as F
8
  from transformers import AutoTokenizer, AutoModel
9
  import chromadb
10
  import sys
 
15
  # --- Configuration ---
16
  CHROMA_PATH = "chroma_db"
17
  COLLECTION_NAME = "bible_verses"
18
+ # *** CHANGE 1: USE A MODEL WITH NORMALIZED EMBEDDINGS ***
19
+ MODEL_NAME = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
20
+ # *** CHANGE 2: USE A NEW REPO FOR THE NEW DATABASE ***
21
+ DATASET_REPO = "broadfield-dev/bible-chromadb-multi-qa-mpnet"
22
  STATUS_FILE = "build_status.log"
23
  JSON_DIRECTORY = 'bible_json'
24
  CHUNK_SIZE = 3
25
+ EMBEDDING_BATCH_SIZE = 16
26
  # (BOOK_ID_TO_NAME dictionary remains the same)
27
  BOOK_ID_TO_NAME = {
28
  1: "Genesis", 2: "Exodus", 3: "Leviticus", 4: "Numbers", 5: "Deuteronomy",
 
46
  with open(STATUS_FILE, "w") as f:
47
  f.write(message)
48
 
49
# Mean pooling — collapse per-token embeddings into one sentence vector,
# ignoring padding positions via the attention mask (crucial for
# sentence-transformer checkpoints loaded through plain AutoModel).
def mean_pooling(model_output, attention_mask):
    """Return masked mean of token embeddings.

    model_output: HF model output; element 0 is the last hidden state
        of shape (batch, seq_len, hidden_dim).
    attention_mask: (batch, seq_len) tensor; 1 for real tokens, 0 for padding.
    """
    embeddings = model_output[0]
    # Broadcast the mask over the hidden dimension so padded tokens
    # contribute nothing to the sum.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = torch.sum(embeddings * mask, 1)
    # Clamp avoids division by zero for an all-padding row.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
54
 
55
+ def process_bible_json_files(directory_path: str, chunk_size: int):
56
  # (This function is unchanged)
57
  all_verses = []
58
  if not os.path.exists(directory_path) or not os.listdir(directory_path):
 
94
 
95
  collection = client.create_collection(
96
  name=COLLECTION_NAME,
97
+ metadata={"hnsw:space": "cosine"}
98
  )
99
 
100
  update_status(f"IN_PROGRESS: Step 3/5 - Loading embedding model '{MODEL_NAME}'...")
101
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
102
  model = AutoModel.from_pretrained(MODEL_NAME, device_map="auto")
103
 
104
+ update_status("IN_PROGRESS: Step 4/5 - Generating embeddings (no normalization needed)...")
105
  for i in tqdm(range(0, len(bible_chunks_df), EMBEDDING_BATCH_SIZE), desc="Embedding Chunks"):
106
  batch_df = bible_chunks_df.iloc[i:i+EMBEDDING_BATCH_SIZE]
107
  texts = batch_df['text'].tolist()
108
 
 
109
  encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(model.device)
110
  with torch.no_grad():
111
  model_output = model(**encoded_input)
112
 
113
+ embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
114
+
115
+ # *** REMOVED: NO LONGER NEED TO NORMALIZE THE EMBEDDINGS ***
116
+ # embeddings = F.normalize(embeddings, p=2, dim=1)
117
 
118
  collection.add(
119
  ids=[str(j) for j in range(i, i + len(batch_df))],
120
+ embeddings=embeddings.cpu().tolist(),
121
  documents=texts,
122
  metadatas=batch_df[['reference', 'version']].to_dict('records')
123
  )
124
 
125
  update_status(f"IN_PROGRESS: Step 5/5 - Pushing database to Hugging Face Hub '{DATASET_REPO}'...")
 
126
  create_repo(repo_id=DATASET_REPO, repo_type="dataset", exist_ok=True)
127
  api = HfApi()
128
  api.upload_folder(
 
137
  try:
138
  main()
139
  except Exception as e:
 
140
  error_message = traceback.format_exc()
141
  if "401" in str(e) or "Unauthorized" in str(e):
142
  update_status("FAILED: Hugging Face authentication error. Ensure your HF_TOKEN secret has WRITE permissions.")