# seoulalpha / rag_retriever.py
# Author: SyngyeonTak
# Commit 64fa191: repo_type change
# rag_retriever.py
import json
import os
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from openai import OpenAI
from huggingface_hub import hf_hub_download
# --- Configuration ---
MODEL_NAME = 'jhgan/ko-sbert-nli'  # Korean SBERT model used to embed queries
LLM_MODEL_NAME = 'gpt-3.5-turbo'  # chat model used for answer generation
DATA_REPO = "Syngyeon/seoulalpha-data"  # HF dataset repo holding the FAISS index + metadata
TOP_K = 10  # NOTE(review): not referenced in this chunk -- callers pass explicit k values
# Initialize the OpenAI client (key is read from the API_KEY environment variable)
client = OpenAI(api_key=os.getenv("API_KEY"))
# --- Resource loading ---
def _load_resources():
    """Download and initialize everything the retriever needs.

    Fetches the FAISS index and the metadata JSONL from the Hugging Face
    Hub, loads the sentence-embedding model, and builds a
    vector_id -> metadata record lookup dict.

    Returns:
        tuple: (embedding model, FAISS index, metadata dict), or
        (None, None, None) when any step fails, so the module can still
        be imported without the resources.
    """
    try:
        print("1. Hugging Face Hubμ—μ„œ RAG λ¦¬μ†ŒμŠ€λ₯Ό λ‹€μš΄λ‘œλ“œν•©λ‹ˆλ‹€...")
        # Pull both artifacts from the dataset repo on the Hub.
        idx_file = hf_hub_download(
            repo_id=DATA_REPO,
            repo_type="dataset",
            filename="data/faiss/faiss_merged_output/merged.index",
        )
        meta_file = hf_hub_download(
            repo_id=DATA_REPO,
            repo_type="dataset",
            filename="data/faiss/faiss_merged_output/merged_metadata.jsonl",
        )
        # Embedding model and vector index.
        encoder = SentenceTransformer(MODEL_NAME)
        vector_index = faiss.read_index(idx_file)
        # One JSON object per line; key each record by its vector id.
        with open(meta_file, 'r', encoding='utf-8') as fh:
            lookup = {
                record['vector_id']: record
                for record in map(json.loads, fh)
            }
        print("RAG λ¦¬μ†ŒμŠ€ λ‘œλ”© μ™„λ£Œ!")
        return encoder, vector_index, lookup
    except Exception as e:
        # Best-effort: report and degrade instead of breaking the import.
        print(f"RAG λ¦¬μ†ŒμŠ€ λ‘œλ”©μ— μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€: {e}")
        return None, None, None
# λͺ¨λ“ˆμ΄ μž„ν¬νŠΈλ  λ•Œ λ¦¬μ†ŒμŠ€λ₯Ό ν•œ 번만 λ‘œλ“œν•©λ‹ˆλ‹€.
embedding_model, faiss_index, meta_map = _load_resources()
def _retrieve_places(query, k):
    """Internal: embed *query* and return up to *k* metadata records.

    Encodes the query with the module-level SBERT model, runs a FAISS
    nearest-neighbour search, and maps the returned vector ids back to
    their metadata records. Fewer than *k* results may be returned.

    Args:
        query (str): natural-language search query.
        k (int): number of neighbours to request from the index.

    Returns:
        list[dict]: metadata records in search-rank order.
    """
    query_vector = embedding_model.encode([query])
    # FAISS expects float32 input; distances are not needed, only ids.
    _, ids = faiss_index.search(query_vector.astype('float32'), k)
    results = []
    for vector_id in ids[0]:
        # FAISS pads the result with -1 when the index holds fewer than
        # k vectors -- skip those slots explicitly.
        if vector_id < 0:
            continue
        # Cast numpy int64 -> int so the key matches the JSON-loaded
        # (plain int) keys in meta_map.
        meta = meta_map.get(int(vector_id))
        if meta is not None:
            results.append(meta)
    return results
def _generate_answer_with_llm(query, retrieved_places):
    """Internal: ask the LLM for a recommendation grounded in the retrieved places.

    Serializes up to five retrieved place records into a context string,
    wraps it in a Korean RAG prompt, and returns the model's reply. On
    any API failure an error-message string is returned instead of
    raising.
    """
    # Only the top five places are rendered into the prompt context.
    parts = []
    for idx, place in enumerate(retrieved_places[:5]):
        parts.append(f"--- μž₯μ†Œ 정보 {idx+1} ---\n")
        parts.append(f"이름: {place.get('name', '정보 μ—†μŒ')}\n")
        parts.append(f"μ£Όμ†Œ: {place.get('address', '정보 μ—†μŒ')}\n")
        parts.append(f"AI μš”μ•½: {place.get('ai_summary', '정보 μ—†μŒ')}\n")
        parts.append("μ£Όμš” νŠΉμ§• 및 ν›„κΈ°:\n")
        for sentence in place.get('processed_sentences', []):
            parts.append(f"- {sentence}\n")
        parts.append("\n")
    context = "".join(parts)
    system_prompt = "당신은 μ‚¬μš©μžμ˜ μ§ˆλ¬Έμ— κ°€μž₯ μ ν•©ν•œ μž₯μ†Œλ₯Ό μΆ”μ²œν•΄μ£ΌλŠ” μœ μš©ν•œ μ–΄μ‹œμŠ€ν„΄νŠΈμž…λ‹ˆλ‹€."
    user_prompt = f"""
μ•„λž˜ 'μž₯μ†Œ 정보'λ§Œμ„ λ°”νƒ•μœΌλ‘œ μ‚¬μš©μžμ˜ μ§ˆλ¬Έμ— λŒ€ν•œ 닡변을 생성해 μ£Όμ„Έμš”.
[μ§€μ‹œμ‚¬ν•­]
1. κ²€μƒ‰λœ μž₯μ†Œ μ€‘μ—μ„œ 질문과 κ°€μž₯ 관련성이 높은 2~3곳을 μΆ”μ²œν•΄ μ£Όμ„Έμš”.
2. 각 μž₯μ†Œλ₯Ό μΆ”μ²œν•  λ•Œ, λ°˜λ“œμ‹œ '이름'κ³Ό 'μ£Όμ†Œ'λ₯Ό λͺ…ν™•ν•˜κ²Œ ν•¨κ»˜ ν‘œμ‹œν•΄μ£Όμ„Έμš”.
3. 각 μž₯μ†Œλ₯Ό μΆ”μ²œν•˜λŠ” 이유λ₯Ό 'AI μš”μ•½'κ³Ό 'μ£Όμš” νŠΉμ§• 및 ν›„κΈ°'λ₯Ό 근거둜 ꡬ체적으둜 μ„€λͺ…ν•΄ μ£Όμ„Έμš”.
4. 'processed_sentences'에 μžˆλŠ” μ‹€μ œ ν›„κΈ°λ₯Ό μΈμš©ν•˜μ—¬ λ‹΅λ³€ν•˜λ©΄ 신뒰도λ₯Ό 높일 수 μžˆμŠ΅λ‹ˆλ‹€.
5. μΉœμ ˆν•˜κ³  μžμ—°μŠ€λŸ¬μš΄ 말투둜 λ‹΅λ³€ν•΄ μ£Όμ„Έμš”.
--- μž₯μ†Œ 정보 ---
{context}
--- μ‚¬μš©μžμ˜ 질문 ---
{query}
"""
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        completion = client.chat.completions.create(
            model=LLM_MODEL_NAME,
            messages=chat_messages,
            temperature=0.7,
        )
    except Exception as e:
        return f"LLM λ‹΅λ³€ 생성 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {e}"
    return completion.choices[0].message.content
# --- Public entry point ---
def get_rag_recommendation(search_query, region_keywords):
    """Run the full RAG pipeline and return the final recommendation text.

    Args:
        search_query (str): natural-language query to embed and search with.
        region_keywords: iterable of address substrings used to filter
            hits; a falsy value skips the region filter entirely.

    Returns:
        str: the LLM-generated recommendation, or a Korean status message
        when resources are unavailable or no matching place is found.
    """
    # Explicit None checks: the loader returns (None, None, None) on
    # failure. Truthiness (as in `all([...])`) is the wrong test here --
    # an empty-but-loaded metadata dict is falsy, and third-party model/
    # index objects need not define a meaningful __bool__.
    if embedding_model is None or faiss_index is None or meta_map is None:
        return "RAG μ‹œμŠ€ν…œμ΄ μ€€λΉ„λ˜μ§€ μ•Šμ•„ μΆ”μ²œμ„ 생성할 수 μ—†μŠ΅λ‹ˆλ‹€."
    # 1. Semantic search: cast a wide net (k=100), filtered down below.
    print("\n[RAG] 의미적으둜 μœ μ‚¬ν•œ μž₯μ†Œλ₯Ό κ²€μƒ‰ν•©λ‹ˆλ‹€...")
    top_places = _retrieve_places(search_query, k=100)
    if not top_places:
        return "κ΄€λ ¨λœ μž₯μ†Œλ₯Ό μ°Ύμ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€."
    # 2. Region filtering by address substring, capped at 10 places.
    if region_keywords:
        print(f"[RAG] μ£Όμ†Œ 필터링 (ν‚€μ›Œλ“œ: {region_keywords})...")
        filtered_places = []
        for place in top_places:
            address = place.get('address', '')
            if any(keyword in address for keyword in region_keywords):
                filtered_places.append(place)
                if len(filtered_places) >= 10:
                    break
        print(f"[RAG] 필터링 ν›„ 남은 μž₯μ†Œ: {[p.get('name') for p in filtered_places]}")
    else:
        print("[RAG] μ§€μ—­ ν‚€μ›Œλ“œκ°€ μ—†μ–΄ 필터링을 κ±΄λ„ˆλœλ‹ˆλ‹€.")
        filtered_places = top_places
    if not filtered_places:
        return "μš”μ²­ν•˜μ‹  지역에 λ§žλŠ” μž₯μ†Œλ₯Ό μ°Ύμ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€."
    # 3. Answer generation grounded in the filtered places.
    print("[RAG] ν•„ν„°λ§λœ 정보λ₯Ό λ°”νƒ•μœΌλ‘œ μ΅œμ’… 닡변을 μƒμ„±ν•©λ‹ˆλ‹€...")
    return _generate_answer_with_llm(search_query, filtered_places)