NavyDevilDoc commited on
Commit
275f2bc
·
verified ·
1 Parent(s): 553edb3

Update src/rag_engine.py

Browse files
Files changed (1) hide show
  1. src/rag_engine.py +30 -4
src/rag_engine.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import shutil
3
  import logging
4
  from typing import List, Tuple, Optional
 
5
  from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader, UnstructuredPowerPointLoader
6
  from langchain_text_splitters import RecursiveCharacterTextSplitter
7
  from langchain_huggingface import HuggingFaceEmbeddings
@@ -27,15 +28,40 @@ except Exception as e:
27
 
28
  def get_embedding_func(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
29
  try:
30
- if "openai" in model_name.lower():
 
31
  if not os.getenv("OPENAI_API_KEY"): raise ValueError("OpenAI API Key not found.")
32
  return OpenAIEmbeddings(model=model_name)
33
-
 
34
  elif "navy-custom-models" in model_name:
35
- return HuggingFaceEmbeddings(model_name=model_name)
36
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  else:
38
  return HuggingFaceEmbeddings(model_name=model_name)
 
39
  except Exception as e:
40
  logger.error(f"Failed to load embedding model '{model_name}': {e}")
41
  return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
2
  import shutil
3
  import logging
4
  from typing import List, Tuple, Optional
5
+ from huggingface_hub import snapshot_download
6
  from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader, UnstructuredPowerPointLoader
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain_huggingface import HuggingFaceEmbeddings
 
28
 
29
  def get_embedding_func(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
30
  try:
31
+ # CHECK 1: OpenAI
32
+ if "openai" in model_name.lower() or "text-embedding" in model_name.lower():
33
  if not os.getenv("OPENAI_API_KEY"): raise ValueError("OpenAI API Key not found.")
34
  return OpenAIEmbeddings(model=model_name)
35
+
36
+ # CHECK 2: YOUR CUSTOM FINE-TUNE (Updated for Subfolders)
37
  elif "navy-custom-models" in model_name:
38
+ logger.info(f"Downloading custom model from: {model_name}")
39
+
40
+ # 1. Parse the repo and folder from your string
41
+ # Input: "NavyDevilDoc/navy-custom-models/bge-finetuned"
42
+ parts = model_name.split("/")
43
+ # Repo ID is the first two parts: "NavyDevilDoc/navy-custom-models"
44
+ repo_id = f"{parts[0]}/{parts[1]}"
45
+ # Folder is the rest: "bge-finetuned"
46
+ folder_name = parts[2]
47
+
48
+ # 2. Download ONLY that folder
49
+ storage_path = snapshot_download(
50
+ repo_id=repo_id,
51
+ repo_type="model",
52
+ allow_patterns=f"{folder_name}/*"
53
+ )
54
+
55
+ # 3. Construct the local path to the inner folder
56
+ local_model_path = os.path.join(storage_path, folder_name)
57
+
58
+ # 4. Load from the local path
59
+ return HuggingFaceEmbeddings(model_name=local_model_path)
60
+
61
+ # CHECK 3: Standard Public Models
62
  else:
63
  return HuggingFaceEmbeddings(model_name=model_name)
64
+
65
  except Exception as e:
66
  logger.error(f"Failed to load embedding model '{model_name}': {e}")
67
  return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")