Spaces:
Sleeping
Sleeping
File size: 3,932 Bytes
9b1e3db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
# modules/knowledge_base.py
import os
import streamlit as st
from pathlib import Path
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import traceback
import config
# Module-level logger, obtained through the project's config helper.
logger = config.get_logger(__name__)
@st.cache_resource
def _load_embedding_model():
    """
    Build the HuggingFace embedding model in its own cached function.

    Keeping the model in a separate ``st.cache_resource`` function lets it be
    reused whenever a FAISS store is loaded instead of being rebuilt.

    Returns:
        The ``HuggingFaceEmbeddings`` instance configured from
        ``config.EMBEDDING_MODEL`` (CPU device, normalized embeddings), or
        ``None`` if construction raised. On failure the exception is logged
        at CRITICAL level with its traceback and reported via ``st.error``.
    """
    # NOTE(review): the function is wrapped in st.cache_resource, so the
    # ``None`` returned on failure is presumably cached as well and the load
    # would not be retried until the cache is cleared -- confirm intended.
    try:
        logger.info("--- [Cache] HuggingFace 임베딩 모델 최초 로딩 시작 ---")

        model_name = config.EMBEDDING_MODEL
        model_kwargs = {'device': 'cpu'}
        encode_kwargs = {'normalize_embeddings': True}

        embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs
        )

        logger.info(f"--- [Cache] HuggingFace 임베딩 모델 ({model_name}) 로딩 성공 ---")
        return embeddings

    except Exception as e:
        # Broad catch is deliberate: any failure is logged, shown in the UI,
        # and signalled to callers as None rather than raised.
        logger.critical(f"--- [CRITICAL ERROR] 임베딩 모델 로딩 실패: {e} ---", exc_info=True)
        st.error(f"임베딩 모델('{config.EMBEDDING_MODEL}') 로딩 중 심각한 오류가 발생했습니다: {e}")
        return None
@st.cache_resource
def load_marketing_vectorstore():
"""
'๋ง์ผํ
์ ๋ต' FAISS Vector Store๋ฅผ ๋ก๋ํ์ฌ Retriever๋ฅผ ์์ฑํฉ๋๋ค.
"""
try:
logger.info("--- [Cache] '๋ง์ผํ
' FAISS Vector Store ์ต์ด ๋ก๋ฉ ์์ ---")
embeddings = _load_embedding_model()
if embeddings is None:
raise RuntimeError("์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ก๋ฉ์ ์คํจํ์ฌ Retriever๋ฅผ ์์ฑํ ์ ์์ต๋๋ค.")
vector_db_path = config.PATH_FAISS_MARKETING
if not vector_db_path.exists():
logger.critical(f"--- [CRITICAL ERROR] '๋ง์ผํ
' Vector DB ๊ฒฝ๋ก๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค: {vector_db_path}")
st.error(f"'๋ง์ผํ
' Vector DB ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค. (๊ฒฝ๋ก: {vector_db_path})")
return None
db = FAISS.load_local(
folder_path=str(vector_db_path),
embeddings=embeddings,
allow_dangerous_deserialization=True
)
retriever = db.as_retriever(search_kwargs={"k": 2})
logger.info("--- [Cache] '๋ง์ผํ
' FAISS Vector Store ๋ก๋ฉ ์ฑ๊ณต ---")
return retriever
except Exception as e:
logger.critical(f"--- [CRITICAL ERROR] '๋ง์ผํ
' FAISS ๋ก๋ฉ ์คํจ: {e} ---", exc_info=True)
st.error(f"'๋ง์ผํ
' Vector Store ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์: {e}")
return None
@st.cache_resource
def load_festival_vectorstore():
"""
'์ถ์ ์ ๋ณด' FAISS Vector Store๋ฅผ ๋ก๋ํฉ๋๋ค.
"""
try:
logger.info("--- [Cache] '์ถ์ ' FAISS Vector Store ์ต์ด ๋ก๋ฉ ์์ ---")
embeddings = _load_embedding_model()
if embeddings is None:
raise RuntimeError("์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ก๋ฉ์ ์คํจํ์ฌ '์ถ์ ' Vector Store๋ฅผ ๋ก๋ํ ์ ์์ต๋๋ค.")
vector_db_path = config.PATH_FAISS_FESTIVAL
if not vector_db_path.exists():
logger.critical(f"--- [CRITICAL ERROR] '์ถ์ ' Vector DB ๊ฒฝ๋ก๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค: {vector_db_path}")
st.error(f"'์ถ์ ' Vector DB ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค. (๊ฒฝ๋ก: {vector_db_path})")
return None
db = FAISS.load_local(
folder_path=str(vector_db_path),
embeddings=embeddings,
allow_dangerous_deserialization=True
)
logger.info("--- [Cache] '์ถ์ ' FAISS Vector Store ๋ก๋ฉ ์ฑ๊ณต ---")
return db
except Exception as e:
logger.critical(f"--- [CRITICAL ERROR] '์ถ์ ' FAISS ๋ก๋ฉ ์คํจ: {e} ---", exc_info=True)
st.error(f"'์ถ์ ' Vector Store ๋ก๋ฉ ์ค ์ค๋ฅ ๋ฐ์: {e}")
return None |