Spaces:
Sleeping
Sleeping
Update src/rag_engine.py
Browse files- src/rag_engine.py +30 -4
src/rag_engine.py
CHANGED
|
@@ -2,6 +2,7 @@ import os
|
|
| 2 |
import shutil
|
| 3 |
import logging
|
| 4 |
from typing import List, Tuple, Optional
|
|
|
|
| 5 |
from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader, UnstructuredPowerPointLoader
|
| 6 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 7 |
from langchain_huggingface import HuggingFaceEmbeddings
|
|
@@ -27,15 +28,40 @@ except Exception as e:
|
|
| 27 |
|
| 28 |
def get_embedding_func(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
|
| 29 |
try:
|
| 30 |
-
|
|
|
|
| 31 |
if not os.getenv("OPENAI_API_KEY"): raise ValueError("OpenAI API Key not found.")
|
| 32 |
return OpenAIEmbeddings(model=model_name)
|
| 33 |
-
|
|
|
|
| 34 |
elif "navy-custom-models" in model_name:
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
else:
|
| 38 |
return HuggingFaceEmbeddings(model_name=model_name)
|
|
|
|
| 39 |
except Exception as e:
|
| 40 |
logger.error(f"Failed to load embedding model '{model_name}': {e}")
|
| 41 |
return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
|
|
|
| 2 |
import shutil
|
| 3 |
import logging
|
| 4 |
from typing import List, Tuple, Optional
|
| 5 |
+
from huggingface_hub import snapshot_download
|
| 6 |
from langchain_community.document_loaders import PyPDFLoader, TextLoader, UnstructuredWordDocumentLoader, UnstructuredPowerPointLoader
|
| 7 |
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 8 |
from langchain_huggingface import HuggingFaceEmbeddings
|
|
|
|
| 28 |
|
| 29 |
def get_embedding_func(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
|
| 30 |
try:
|
| 31 |
+
# CHECK 1: OpenAI
|
| 32 |
+
if "openai" in model_name.lower() or "text-embedding" in model_name.lower():
|
| 33 |
if not os.getenv("OPENAI_API_KEY"): raise ValueError("OpenAI API Key not found.")
|
| 34 |
return OpenAIEmbeddings(model=model_name)
|
| 35 |
+
|
| 36 |
+
# CHECK 2: YOUR CUSTOM FINE-TUNE (Updated for Subfolders)
|
| 37 |
elif "navy-custom-models" in model_name:
|
| 38 |
+
logger.info(f"Downloading custom model from: {model_name}")
|
| 39 |
+
|
| 40 |
+
# 1. Parse the repo and folder from your string
|
| 41 |
+
# Input: "NavyDevilDoc/navy-custom-models/bge-finetuned"
|
| 42 |
+
parts = model_name.split("/")
|
| 43 |
+
# Repo ID is the first two parts: "NavyDevilDoc/navy-custom-models"
|
| 44 |
+
repo_id = f"{parts[0]}/{parts[1]}"
|
| 45 |
+
# Folder is the rest: "bge-finetuned"
|
| 46 |
+
folder_name = parts[2]
|
| 47 |
+
|
| 48 |
+
# 2. Download ONLY that folder
|
| 49 |
+
storage_path = snapshot_download(
|
| 50 |
+
repo_id=repo_id,
|
| 51 |
+
repo_type="model",
|
| 52 |
+
allow_patterns=f"{folder_name}/*"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# 3. Construct the local path to the inner folder
|
| 56 |
+
local_model_path = os.path.join(storage_path, folder_name)
|
| 57 |
+
|
| 58 |
+
# 4. Load from the local path
|
| 59 |
+
return HuggingFaceEmbeddings(model_name=local_model_path)
|
| 60 |
+
|
| 61 |
+
# CHECK 3: Standard Public Models
|
| 62 |
else:
|
| 63 |
return HuggingFaceEmbeddings(model_name=model_name)
|
| 64 |
+
|
| 65 |
except Exception as e:
|
| 66 |
logger.error(f"Failed to load embedding model '{model_name}': {e}")
|
| 67 |
return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|