# RFP_summary_chatbot / src / embedding / rag_data_processing.py
# Initial commit for HF Spaces deployment (Dongjin1203, rev 4739096)
import pandas as pd
from langchain_chroma import Chroma
from langchain_openai.embeddings import OpenAIEmbeddings
import os
from tqdm import tqdm
import time
from src.utils.config import RAGConfig
class DataValidator:
    """Validate and clean chunk data before it is embedded.

    The input DataFrame must contain 'chunk_content' and 'chunk_id'
    columns; every other column is treated as string metadata.
    """

    def __init__(self, config: "RAGConfig"):
        self.config = config

    def validate_and_clean(self, df: pd.DataFrame) -> pd.DataFrame:
        """Run the full validation/cleaning pipeline and return the result."""
        df = self._check_required_columns(df)
        df = self._remove_duplicates(df)
        df = self._remove_nan(df)
        df = self._filter_by_length(df)
        df = self._clean_metadata(df)
        return df

    def _check_required_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Raise ValueError when a required column is missing."""
        required = ['chunk_content', 'chunk_id']
        missing = [col for col in required if col not in df.columns]
        if missing:
            raise ValueError(f"ํ•„์ˆ˜ ์ปฌ๋Ÿผ ๋ˆ„๋ฝ: {missing}")
        return df

    def _remove_duplicates(self, df: pd.DataFrame) -> pd.DataFrame:
        """Drop rows with a duplicated chunk_id, keeping the first occurrence."""
        return df.drop_duplicates(subset=['chunk_id'], keep='first')

    def _remove_nan(self, df: pd.DataFrame) -> pd.DataFrame:
        """Drop rows whose content or id is NaN."""
        return df.dropna(subset=['chunk_content', 'chunk_id'])

    def _filter_by_length(self, df: pd.DataFrame) -> pd.DataFrame:
        """Keep rows whose content length is within [MIN, MAX] (inclusive).

        Filters via a boolean mask instead of writing a temporary column:
        the frame arriving here is the result of chained drop operations,
        and assigning a column to such a frame triggers pandas'
        SettingWithCopyWarning (and may silently write to a view).
        """
        lengths = df['chunk_content'].str.len()
        # Series.between is inclusive on both ends, matching >= / <=.
        in_range = lengths.between(self.config.MIN_CHUNK_LENGTH,
                                   self.config.MAX_CHUNK_LENGTH)
        return df[in_range]

    def _clean_metadata(self, df: pd.DataFrame) -> pd.DataFrame:
        """Normalize metadata columns: NaN -> '' and force string dtype."""
        # fillna returns a new frame, so the column writes below are safe.
        df = df.fillna('')
        metadata_cols = [col for col in df.columns
                         if col not in ['chunk_content', 'chunk_id']]
        for col in metadata_cols:
            df[col] = df[col].astype(str)
        return df
class ChromaDBBuilder:
    """Build and query a ChromaDB vector store from cleaned chunk data."""

    def __init__(self, config: "RAGConfig"):
        self.config = config
        self.vectorstore = None
        self.embeddings = None
        self._initialize_embeddings()

    def _initialize_embeddings(self):
        """Initialize the OpenAI embedding model from the config."""
        os.environ["OPENAI_API_KEY"] = self.config.OPENAI_API_KEY
        self.embeddings = OpenAIEmbeddings(
            model=self.config.EMBEDDING_MODEL_NAME
        )

    def build_from_dataframe(self, df: pd.DataFrame):
        """Build the vector DB from a DataFrame and return the store."""
        documents, ids, metadatas = self._prepare_data(df)
        self._validate_data_consistency(documents, ids, metadatas)
        self._create_vectorstore()
        self._add_documents_in_batches(documents, ids, metadatas)
        return self.vectorstore

    def _prepare_data(self, df: pd.DataFrame):
        """Split the frame into documents, ids, and per-row metadata dicts.

        Metadata values that are empty/falsy or the literal string 'nan'
        are dropped so meaningless entries are not stored.
        """
        documents = df['chunk_content'].tolist()
        ids = df['chunk_id'].tolist()
        metadata_cols = [col for col in df.columns
                         if col not in ['chunk_content', 'chunk_id']]
        metadatas = []
        for _, row in df.iterrows():
            metadata = {
                col: row[col]
                for col in metadata_cols
                if row[col] and row[col] != 'nan' and row[col] != ''
            }
            metadatas.append(metadata)
        return documents, ids, metadatas

    def _validate_data_consistency(self, documents, ids, metadatas):
        """Raise ValueError if the three parallel lists differ in length."""
        if not (len(documents) == len(ids) == len(metadatas)):
            raise ValueError("๋ฐ์ดํ„ฐ ๊ธธ์ด ๋ถˆ์ผ์น˜")

    def _create_vectorstore(self):
        """Create an empty persistent Chroma collection."""
        self.vectorstore = Chroma(
            embedding_function=self.embeddings,
            persist_directory=self.config.DB_DIRECTORY,
            collection_name=self.config.COLLECTION_NAME
        )

    def _add_documents_in_batches(self, documents, ids, metadatas):
        """Add documents in BATCH_SIZE slices, pausing between batches."""
        batch_size = self.config.BATCH_SIZE
        total_batches = (len(documents) + batch_size - 1) // batch_size
        for i in tqdm(range(0, len(documents), batch_size),
                      desc="์ž„๋ฒ ๋”ฉ ๋ฐ ์ €์žฅ",
                      total=total_batches):
            batch_docs = documents[i:i + batch_size]
            batch_ids = ids[i:i + batch_size]
            batch_metas = metadatas[i:i + batch_size]
            self._add_batch_with_retry(batch_docs, batch_ids, batch_metas)
            # Throttle between batches to stay under API rate limits.
            time.sleep(1)

    def _add_batch_with_retry(self, docs, ids, metas):
        """Add one batch, splitting it when it is too large or on failure.

        Token count is a rough heuristic (~4 characters per token).
        """
        batch_tokens = sum(len(doc) for doc in docs) / 4
        if batch_tokens > self.config.MAX_TOKENS_PER_BATCH:
            # Oversized batch: send it in two halves. max(..., 1) guards
            # against a zero range step (ValueError) when the batch holds
            # a single very long document.
            smaller_size = max(1, len(docs) // 2)
            for j in range(0, len(docs), smaller_size):
                self.vectorstore.add_texts(
                    texts=docs[j:j + smaller_size],
                    metadatas=metas[j:j + smaller_size],
                    ids=ids[j:j + smaller_size]
                )
                time.sleep(0.5)
        else:
            try:
                self.vectorstore.add_texts(
                    texts=docs,
                    metadatas=metas,
                    ids=ids
                )
            except Exception:
                # Best-effort fallback: retry in chunks of 10 documents.
                # A chunk that fails again propagates to the caller.
                for j in range(0, len(docs), 10):
                    self.vectorstore.add_texts(
                        texts=docs[j:j + 10],
                        metadatas=metas[j:j + 10],
                        ids=ids[j:j + 10]
                    )
                    time.sleep(0.5)

    def get_collection_count(self):
        """Return the number of stored documents (0 if no store yet)."""
        if self.vectorstore:
            return self.vectorstore._collection.count()
        return 0

    def search(self, query: str, k: int = 5):
        """Run a similarity search; raise if the store is not initialized."""
        if not self.vectorstore:
            raise ValueError("๋ฒกํ„ฐ์Šคํ† ์–ด๊ฐ€ ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค")
        return self.vectorstore.similarity_search_with_score(query, k=k)
class RAGVectorDBPipeline:
    """End-to-end pipeline: load chunk CSV, clean it, persist to ChromaDB."""

    def __init__(self, config: RAGConfig = None):
        # Fall back to the default configuration when none is supplied.
        self.config = RAGConfig() if not config else config
        self.validator = DataValidator(self.config)
        self.builder = ChromaDBBuilder(self.config)

    def build(self):
        """Execute the whole pipeline and return the built vector store."""
        raw = pd.read_csv(self.config.RAG_INPUT_PATH)
        print(f"์›๋ณธ ๋ฐ์ดํ„ฐ: {len(raw)}๊ฐœ ์ฒญํฌ")

        cleaned = self.validator.validate_and_clean(raw)
        print(f"์ •์ œ ํ›„ ๋ฐ์ดํ„ฐ: {len(cleaned)}๊ฐœ ์ฒญํฌ")

        store = self.builder.build_from_dataframe(cleaned)

        stored = self.builder.get_collection_count()
        print(f"โœ… ChromaDB ์ €์žฅ ์™„๋ฃŒ: {stored}๊ฐœ ๋ฌธ์„œ")
        print(f"์ €์žฅ ์œ„์น˜: {self.config.DB_DIRECTORY}")
        return store

    def test_search(self, query: str = "ํ•™์‚ฌ ์ •๋ณด ์‹œ์Šคํ…œ", k: int = 3):
        """Run a sample query against the store and print the hits."""
        results = self.builder.search(query, k=k)
        print(f"\nํ…Œ์ŠคํŠธ ์ฟผ๋ฆฌ: '{query}'")
        print(f"๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ: {len(results)}๊ฐœ\n")
        for rank, (doc, score) in enumerate(results, start=1):
            print(f"[{rank}] ๊ฑฐ๋ฆฌ: {score:.4f}")
            print(f"๋‚ด์šฉ: {doc.page_content[:100]}...")
            print(f"๋ฉ”ํƒ€๋ฐ์ดํ„ฐ: {doc.metadata}\n")
        return results