Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| from typing import List, Dict, Tuple | |
| from dataclasses import dataclass | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| from langchain_core.documents import Document | |
| from document_processor import DocumentProcessor | |
| from vector_store import VectorStore | |
| from config import Config | |
| class ChatResponse: | |
| """챗봇 응답 결과 클래스""" | |
| answer: str | |
| sources: List[Dict] | |
| confidence: float | |
| response_time: float | |
| class RAGChatbot: | |
| """소방 복무관리 RAG 챗봇""" | |
| def __init__(self): | |
| self.document_processor = DocumentProcessor( | |
| chunk_size=Config.CHUNK_SIZE, | |
| chunk_overlap=Config.CHUNK_OVERLAP | |
| ) | |
| self.vector_store = VectorStore() | |
| self.llm = None | |
| self.llm_tokenizer = None | |
| self.is_initialized = False | |
| def initialize(self, docs_folder: str = None, force_rebuild: bool = False): | |
| """챗봇 초기화""" | |
| print("🤖 소방 복무관리 RAG 챗봇 초기화 중...") | |
| # 1. 문서 로드 및 처리 | |
| docs_folder = docs_folder or Config.DOCS_FOLDER | |
| documents = self._load_documents(docs_folder) | |
| if not documents: | |
| print("❌ 처리할 문서가 없습니다. documents 폴더에 파일을 넣어주세요.") | |
| return False | |
| # 2. 벡터 데이터베이스 구축 | |
| success = self.vector_store.rebuild_if_needed(documents, force_rebuild) | |
| if not success: | |
| print("❌ 벡터 데이터베이스 구축 실패") | |
| return False | |
| # 3. LLM 모델 로드 (선택적 - 메모리 부족 시 스킵) | |
| try: | |
| self._load_llm() | |
| except Exception as e: | |
| print(f"⚠️ LLM 모델 로드 실패: {str(e)}") | |
| print("📝 템플릿 기반 응답 모드로 동작합니다.") | |
| self.is_initialized = True | |
| print("✅ RAG 챗봇 초기화 완료") | |
| return True | |
| def _load_documents(self, docs_folder: str) -> List[Document]: | |
| """문서 로드 및 처리""" | |
| if not os.path.exists(docs_folder): | |
| print(f"⚠️ 문서 폴더가 존재하지 않습니다: {docs_folder}") | |
| return [] | |
| print(f"📂 문서 폴더: {docs_folder}") | |
| raw_documents = self.document_processor.load_documents_from_folder(docs_folder) | |
| processed_documents = self.document_processor.process_documents(raw_documents) | |
| print(f"✅ 총 {len(processed_documents)}개 문서 청크 생성 완료") | |
| return processed_documents | |
| def _load_llm(self): | |
| """LLM 모델 로드""" | |
| print(f"🧠 LLM 모델 로드: {Config.LLM_MODEL}") | |
| try: | |
| self.llm_tokenizer = AutoTokenizer.from_pretrained( | |
| Config.LLM_MODEL, | |
| trust_remote_code=True | |
| ) | |
| # 패딩 토큰 설정 | |
| if self.llm_tokenizer.pad_token is None: | |
| self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token | |
| self.llm = AutoModelForCausalLM.from_pretrained( | |
| Config.LLM_MODEL, | |
| torch_dtype=torch.float16, | |
| device_map="auto", | |
| trust_remote_code=True | |
| ) | |
| print("✅ LLM 모델 로드 완료") | |
| except Exception as e: | |
| raise Exception(f"LLM 모델 로드 실패: {str(e)}") | |
| def search_relevant_docs(self, query: str, k: int = 3) -> List[Tuple[Document, float]]: | |
| """관련 문서 검색""" | |
| if not self.is_initialized: | |
| print("⚠️ 챗봇이 초기화되지 않았습니다.") | |
| return [] | |
| # 쿼리 전처리 | |
| processed_query = self._preprocess_query(query) | |
| # 벡터 검색 | |
| results = self.vector_store.search_similar(processed_query, k) | |
| # 유사도 필터링 | |
| filtered_results = [ | |
| (doc, similarity) for doc, similarity in results | |
| if similarity > 0.3 # 최소 유사도 임계값 | |
| ] | |
| return filtered_results | |
| def _preprocess_query(self, query: str) -> str: | |
| """쿼리 전처리""" | |
| # 불필요한 공백 제거 | |
| query = re.sub(r'\s+', ' ', query.strip()) | |
| # 복무관리 관련 키워드 강화 | |
| keyword_mappings = { | |
| "연차": "연차휴가", | |
| "휴가": "휴가사용", | |
| "근무": "근무시간", | |
| "당직": "당직근무", | |
| "인사": "인사평가", | |
| "승진": "승진시험" | |
| } | |
| for keyword, enhanced in keyword_mappings.items(): | |
| if keyword in query and enhanced not in query: | |
| query = query.replace(keyword, enhanced) | |
| return query | |
| def generate_answer(self, query: str, use_llm: bool = True) -> ChatResponse: | |
| """질문에 대한 답변 생성""" | |
| import time | |
| start_time = time.time() | |
| if not self.is_initialized: | |
| return ChatResponse( | |
| answer="죄송합니다. 챗봇이 초기화되지 않았습니다. 관리자에게 문의해주세요.", | |
| sources=[], | |
| confidence=0.0, | |
| response_time=time.time() - start_time | |
| ) | |
| # 1. 관련 문서 검색 | |
| relevant_docs = self.search_relevant_docs(query, k=Config.MAX_RETRIEVE_DOCS) | |
| if not relevant_docs: | |
| return ChatResponse( | |
| answer="죄송합니다. 질문과 관련된 정보를 찾을 수 없습니다. 다른 방식으로 질문해주시거나 관련 부서에 문의해주시기 바랍니다.", | |
| sources=[], | |
| confidence=0.0, | |
| response_time=time.time() - start_time | |
| ) | |
| # 2. 답변 생성 | |
| if use_llm and self.llm is not None: | |
| answer = self._generate_llm_answer(query, relevant_docs) | |
| else: | |
| answer = self._generate_template_answer(query, relevant_docs) | |
| # 3. 출처 정보 준비 | |
| sources = [ | |
| { | |
| "source": doc.metadata.get("source", "알 수 없음"), | |
| "content": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content, | |
| "similarity": f"{similarity:.4f}" | |
| } | |
| for doc, similarity in relevant_docs | |
| ] | |
| # 4. 신뢰도 계산 | |
| confidence = min(sum(similarity for _, similarity in relevant_docs) / len(relevant_docs), 1.0) | |
| return ChatResponse( | |
| answer=answer, | |
| sources=sources, | |
| confidence=confidence, | |
| response_time=time.time() - start_time | |
| ) | |
| def _generate_llm_answer(self, query: str, relevant_docs: List[Tuple[Document, float]]) -> str: | |
| """LLM으로 답변 생성""" | |
| try: | |
| # 문맥 구성 | |
| context = "\n\n".join([ | |
| f"[출처 {i+1}] {doc.page_content}" | |
| for i, (doc, _) in enumerate(relevant_docs) | |
| ]) | |
| # 프롬프트 구성 | |
| prompt = f"""{Config.SYSTEM_PROMPT} | |
| [참고자료] | |
| {context} | |
| [질문] | |
| {query} | |
| 위 참고자료를 바탕으로 질문에 답변해주세요. 정확하고 친절하게 설명해주세요.""" | |
| # 토크나이징 | |
| inputs = self.llm_tokenizer( | |
| prompt, | |
| return_tensors="pt", | |
| max_length=2048, | |
| truncation=True | |
| ) | |
| # 생성 | |
| with torch.no_grad(): | |
| outputs = self.llm.generate( | |
| inputs.input_ids, | |
| max_new_tokens=512, | |
| temperature=0.7, | |
| do_sample=True, | |
| pad_token_id=self.llm_tokenizer.eos_token_id | |
| ) | |
| # 결과 디코딩 | |
| answer = self.llm_tokenizer.decode( | |
| outputs[0][inputs.input_ids.shape[1]:], | |
| skip_special_tokens=True | |
| ).strip() | |
| return answer | |
| except Exception as e: | |
| print(f"⚠️ LLM 답변 생성 실패: {str(e)}") | |
| return self._generate_template_answer(query, relevant_docs) | |
| def _generate_template_answer(self, query: str, relevant_docs: List[Tuple[Document, float]]) -> str: | |
| """템플릿 기반 답변 생성""" | |
| # 쿼리 분석 | |
| query_lower = query.lower() | |
| # 가장 관련성 높은 문서 | |
| top_doc, top_similarity = relevant_docs[0] | |
| # 기본 답변 형식 | |
| if "연차" in query_lower or "휴가" in query_lower: | |
| return self._format_leave_answer(top_doc, query) | |
| elif "근무시간" in query_lower or "시간" in query_lower: | |
| return self._format_work_hours_answer(top_doc, query) | |
| elif "당직" in query_lower: | |
| return self._format_duty_answer(top_doc, query) | |
| elif "인사" in query_lower or "평가" in query_lower: | |
| return self._format_evaluation_answer(top_doc, query) | |
| else: | |
| return self._format_general_answer(top_doc, query) | |
| def _format_leave_answer(self, doc: Document, query: str) -> str: | |
| """휴가 관련 답변 형식""" | |
| content = doc.page_content | |
| answer = f"📅 연차휴가 안내\n\n" | |
| # 숫자와 관련된 내용 추출 | |
| import re | |
| days = re.findall(r'(\d+)일', content) | |
| periods = re.findall(r'(\d+)일 전', content) | |
| if days: | |
| answer += f"• 연차휴가 일수: {days[0]}일\n" | |
| if periods: | |
| answer += f"• 신청 기한: {periods[0]}일 전\n" | |
| answer += f"\n{content[:300]}..." | |
| if len(content) > 300: | |
| answer += "\n\n📋 자세한 내용은 관련 규정을 확인하시거나 인사담당자에게 문의해주세요." | |
| return answer | |
| def _format_work_hours_answer(self, doc: Document, query: str) -> str: | |
| """근무시간 관련 답변 형식""" | |
| content = doc.page_content | |
| answer = f"⏰ 근무시간 안내\n\n" | |
| answer += f"{content[:400]}..." | |
| # 시간 정보 추출 | |
| import re | |
| times = re.findall(r'\d{2}:\d{2}', content) | |
| if times: | |
| answer += f"\n\n🕐 주요 시간: {', '.join(times)}" | |
| return answer | |
| def _format_duty_answer(self, doc: Document, query: str) -> str: | |
| """당직 관련 답변 형식""" | |
| answer = f"🌙 당직근무 안내\n\n" | |
| answer += f"{doc.page_content[:400]}..." | |
| answer += "\n\n📞 당직 관련 추가 문의는 관리부서로 연락주세요." | |
| return answer | |
| def _format_evaluation_answer(self, doc: Document, query: str) -> str: | |
| """인사평가 관련 답변 형식""" | |
| answer = f"📊 인사평가 안내\n\n" | |
| answer += f"{doc.page_content[:400]}..." | |
| answer += "\n\n💡 평가 관련 구체적인 문의는 인사담당자에게 문의해주세요." | |
| return answer | |
| def _format_general_answer(self, doc: Document, query: str) -> str: | |
| """일반 답변 형식""" | |
| answer = f"📋 복무관리 안내\n\n" | |
| answer += f"질문: {query}\n\n" | |
| answer += f"관련 정보:\n{doc.page_content[:400]}..." | |
| if len(doc.page_content) > 400: | |
| answer += "\n\n📖 더 자세한 정보는 관련 규정 파일을 확인해주세요." | |
| return answer | |
| def get_stats(self) -> Dict: | |
| """챗봇 통계 정보""" | |
| if not self.is_initialized: | |
| return {"status": "not_initialized"} | |
| vector_stats = self.vector_store.get_stats() | |
| return { | |
| "status": "initialized", | |
| "vector_store": vector_stats, | |
| "llm_available": self.llm is not None, | |
| "system_prompt": Config.SYSTEM_PROMPT[:100] + "..." | |
| } | |
| # 테스트용 함수 | |
| def test_rag_chatbot(): | |
| """RAG 챗봇 테스트""" | |
| # 샘플 문서 폴더 확인 | |
| if not os.path.exists("documents"): | |
| print("⚠️ documents 폴더가 없습니다. document_processor.py를 먼저 실행해주세요.") | |
| return | |
| # 챗봇 초기화 | |
| chatbot = RAGChatbot() | |
| success = chatbot.initialize() | |
| if not success: | |
| return | |
| # 테스트 질문 | |
| test_questions = [ | |
| "연차휴가는 어떻게 사용하나요?", | |
| "정규근무시간은 어떻게 되나요?", | |
| "당직근무가 무엇인가요?", | |
| "인사평가 절차가 궁금합니다." | |
| ] | |
| # 질문 테스트 | |
| for question in test_questions: | |
| print(f"\n❓ 질문: {question}") | |
| response = chatbot.generate_answer(question, use_llm=False) # 템플릿 모드로 테스트 | |
| print(f"🤖 답변: {response.answer[:300]}...") | |
| print(f"📊 신뢰도: {response.confidence:.4f}") | |
| print(f"⏱️ 응답시간: {response.response_time:.4f}초") | |
| print(f"📚 출처: {len(response.sources)}개") | |
| # 통계 정보 | |
| print(f"\n📈 챗봇 통계: {chatbot.get_stats()}") | |
| if __name__ == "__main__": | |
| test_rag_chatbot() |