MarketSync / modules /knowledge_base.py
hyeonjoo's picture
Initial project commit with LFS
9b1e3db
# modules/knowledge_base.py
import os
import streamlit as st
from pathlib import Path
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import traceback
import config
# Module-level logger, obtained through the project's config helper.
logger = config.get_logger(__name__)
@st.cache_resource
def _load_embedding_model():
    """Build the HuggingFace embedding model named by ``config.EMBEDDING_MODEL``.

    The model lives in its own ``st.cache_resource``-decorated function so
    that the FAISS loaders in this module, which both call it, can reuse
    the same instance instead of constructing their own.

    Returns:
        The ``HuggingFaceEmbeddings`` instance, or ``None`` when anything in
        the construction path raised. In that case the failure is logged at
        CRITICAL level with a traceback and reported through ``st.error``.

    NOTE(review): the ``None`` failure result is presumably cached by
    ``st.cache_resource`` as well -- confirm whether a retry is expected.
    """
    try:
        logger.info("--- [Cache] HuggingFace ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ์ตœ์ดˆ ๋กœ๋”ฉ ์‹œ์ž‘ ---")
        model_name = config.EMBEDDING_MODEL
        # Run on CPU and ask the encoder for normalised embeddings.
        embedder = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )
        logger.info(f"--- [Cache] HuggingFace ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ({model_name}) ๋กœ๋”ฉ ์„ฑ๊ณต ---")
        return embedder
    except Exception as e:
        # Broad catch is deliberate: callers test for None rather than
        # handling exceptions themselves.
        logger.critical(
            f"--- [CRITICAL ERROR] ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ: {e} ---",
            exc_info=True,
        )
        st.error(f"์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ('{config.EMBEDDING_MODEL}') ๋กœ๋”ฉ ์ค‘ ์‹ฌ๊ฐํ•œ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")
        return None
@st.cache_resource
def load_marketing_vectorstore():
    """Load the marketing FAISS vector store and wrap it in a retriever.

    The index is read from ``config.PATH_FAISS_MARKETING`` using the
    embedding model from ``_load_embedding_model()``; the resulting
    retriever is configured with ``search_kwargs={"k": 2}``.

    Returns:
        The retriever on success. ``None`` when the embedding model is
        unavailable, the index path does not exist, or loading raised; each
        failure is logged at CRITICAL level and reported through ``st.error``.

    NOTE(review): the ``None`` failure result is presumably cached by
    ``st.cache_resource`` as well -- confirm whether a retry is expected.
    """
    try:
        logger.info("--- [Cache] '๋งˆ์ผ€ํŒ…' FAISS Vector Store ์ตœ์ดˆ ๋กœ๋”ฉ ์‹œ์ž‘ ---")

        embedder = _load_embedding_model()
        if embedder is None:
            # Raised inside the try on purpose so the handler below does the
            # logging and user-facing reporting.
            raise RuntimeError("์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋กœ๋”ฉ์— ์‹คํŒจํ•˜์—ฌ Retriever๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")

        vector_db_path = config.PATH_FAISS_MARKETING
        if vector_db_path.exists():
            # allow_dangerous_deserialization opts in to pickle-based loading
            # (per the LangChain docs); the path must point at a trusted index.
            store = FAISS.load_local(
                folder_path=str(vector_db_path),
                embeddings=embedder,
                allow_dangerous_deserialization=True,
            )
            retriever = store.as_retriever(search_kwargs={"k": 2})
            logger.info("--- [Cache] '๋งˆ์ผ€ํŒ…' FAISS Vector Store ๋กœ๋”ฉ ์„ฑ๊ณต ---")
            return retriever

        # Index directory is missing: report it and give up.
        logger.critical(f"--- [CRITICAL ERROR] '๋งˆ์ผ€ํŒ…' Vector DB ๊ฒฝ๋กœ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {vector_db_path}")
        st.error(f"'๋งˆ์ผ€ํŒ…' Vector DB ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. (๊ฒฝ๋กœ: {vector_db_path})")
        return None
    except Exception as e:
        logger.critical(
            f"--- [CRITICAL ERROR] '๋งˆ์ผ€ํŒ…' FAISS ๋กœ๋”ฉ ์‹คํŒจ: {e} ---",
            exc_info=True,
        )
        st.error(f"'๋งˆ์ผ€ํŒ…' Vector Store ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        return None
@st.cache_resource
def load_festival_vectorstore():
    """Load the festival FAISS vector store.

    The index is read from ``config.PATH_FAISS_FESTIVAL`` using the
    embedding model from ``_load_embedding_model()``. Unlike the marketing
    loader in this module, the store itself is returned, not a retriever.

    Returns:
        The loaded FAISS store on success. ``None`` when the embedding model
        is unavailable, the index path does not exist, or loading raised;
        each failure is logged at CRITICAL level and reported through
        ``st.error``.

    NOTE(review): the ``None`` failure result is presumably cached by
    ``st.cache_resource`` as well -- confirm whether a retry is expected.
    """
    try:
        logger.info("--- [Cache] '์ถ•์ œ' FAISS Vector Store ์ตœ์ดˆ ๋กœ๋”ฉ ์‹œ์ž‘ ---")

        embedder = _load_embedding_model()
        if embedder is None:
            # Raised inside the try on purpose so the handler below does the
            # logging and user-facing reporting.
            raise RuntimeError("์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋กœ๋”ฉ์— ์‹คํŒจํ•˜์—ฌ '์ถ•์ œ' Vector Store๋ฅผ ๋กœ๋“œํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")

        vector_db_path = config.PATH_FAISS_FESTIVAL
        if vector_db_path.exists():
            # allow_dangerous_deserialization opts in to pickle-based loading
            # (per the LangChain docs); the path must point at a trusted index.
            store = FAISS.load_local(
                folder_path=str(vector_db_path),
                embeddings=embedder,
                allow_dangerous_deserialization=True,
            )
            logger.info("--- [Cache] '์ถ•์ œ' FAISS Vector Store ๋กœ๋”ฉ ์„ฑ๊ณต ---")
            return store

        # Index directory is missing: report it and give up.
        logger.critical(f"--- [CRITICAL ERROR] '์ถ•์ œ' Vector DB ๊ฒฝ๋กœ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {vector_db_path}")
        st.error(f"'์ถ•์ œ' Vector DB ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. (๊ฒฝ๋กœ: {vector_db_path})")
        return None
    except Exception as e:
        logger.critical(
            f"--- [CRITICAL ERROR] '์ถ•์ œ' FAISS ๋กœ๋”ฉ ์‹คํŒจ: {e} ---",
            exc_info=True,
        )
        st.error(f"'์ถ•์ œ' Vector Store ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        return None