# NOTE: the original capture of this file included a Hugging Face Space page
# banner here ("Spaces: Sleeping") — UI chrome, not part of the source.
| # rag_retriever.py | |
| import json | |
| import os | |
| import faiss | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from openai import OpenAI | |
| from huggingface_hub import hf_hub_download | |
# --- Configuration ---
MODEL_NAME = 'jhgan/ko-sbert-nli'        # Korean SBERT used to embed search queries
LLM_MODEL_NAME = 'gpt-3.5-turbo'         # chat model used to generate the final answer
DATA_REPO = "Syngyeon/seoulalpha-data"   # HF dataset repo holding the FAISS index + metadata
TOP_K = 10  # NOTE(review): not referenced anywhere in this file — confirm before removing
# Initialize the OpenAI client (API key comes from the API_KEY environment variable)
client = OpenAI(api_key=os.getenv("API_KEY"))
# --- Resource loading ---
def _load_resources():
    """Fetch and initialize everything retrieval needs.

    Downloads the prebuilt FAISS index and its metadata from the Hugging
    Face dataset repo, loads the sentence-embedding model, and builds a
    vector-id -> metadata-record map.

    Returns:
        (model, index, metadata_map) on success, or (None, None, None)
        when any step fails (failure is reported via print, not raised).
    """
    try:
        print("1. Hugging Face Hubμμ RAG 리μμ€λ₯Ό λ€μ΄λ‘λν©λλ€...")
        # Pull both artifacts from the dataset repo.
        artifact_dir = "data/faiss/faiss_merged_output"
        local_paths = {
            fname: hf_hub_download(
                repo_id=DATA_REPO,
                repo_type="dataset",
                filename=f"{artifact_dir}/{fname}",
            )
            for fname in ("merged.index", "merged_metadata.jsonl")
        }
        # Embedding model + vector index.
        sbert = SentenceTransformer(MODEL_NAME)
        vector_index = faiss.read_index(local_paths["merged.index"])
        # One metadata record per line, keyed by its vector id.
        id_to_meta = {}
        with open(local_paths["merged_metadata.jsonl"], 'r', encoding='utf-8') as fh:
            for raw_line in fh:
                record = json.loads(raw_line)
                id_to_meta[record['vector_id']] = record
        print("RAG 리μμ€ λ‘λ© μλ£!")
        return sbert, vector_index, id_to_meta
    except Exception as e:
        print(f"RAG 리μμ€ λ‘λ©μ μ€ν¨νμ΅λλ€: {e}")
        return None, None, None
# Load the heavy resources exactly once, at module import time.
# Each of these is None when loading failed (see _load_resources).
embedding_model, faiss_index, meta_map = _load_resources()
def _retrieve_places(query, k):
    """Return metadata records for the k places most similar to *query*."""
    # Embed the query and run a nearest-neighbour search on the FAISS index.
    embedded = embedding_model.encode([query]).astype('float32')
    _, neighbour_ids = faiss_index.search(embedded, k)
    # The index may return ids with no metadata entry (e.g. -1 padding);
    # keep only the ids we actually know about, in search order.
    return [meta_map[vid] for vid in neighbour_ids[0] if vid in meta_map]
| def _generate_answer_with_llm(query, retrieved_places): | |
| """λ΄λΆ ν¨μ: κ²μλ μ 보λ₯Ό λ°νμΌλ‘ LLM λ΅λ³μ μμ±ν©λλ€.""" | |
| context = "" | |
| for i, place in enumerate(retrieved_places[:5]): # μμ 5κ° μ λ³΄λ§ μ¬μ© | |
| context += f"--- μ₯μ μ 보 {i+1} ---\n" | |
| context += f"μ΄λ¦: {place.get('name', 'μ 보 μμ')}\n" | |
| context += f"μ£Όμ: {place.get('address', 'μ 보 μμ')}\n" | |
| context += f"AI μμ½: {place.get('ai_summary', 'μ 보 μμ')}\n" | |
| processed_sentences = place.get('processed_sentences', []) | |
| context += "μ£Όμ νΉμ§ λ° νκΈ°:\n" | |
| for sentence in processed_sentences: | |
| context += f"- {sentence}\n" | |
| context += "\n" | |
| system_prompt = "λΉμ μ μ¬μ©μμ μ§λ¬Έμ κ°μ₯ μ ν©ν μ₯μλ₯Ό μΆμ²ν΄μ£Όλ μ μ©ν μ΄μμ€ν΄νΈμ λλ€." | |
| user_prompt = f""" | |
| μλ 'μ₯μ μ 보'λ§μ λ°νμΌλ‘ μ¬μ©μμ μ§λ¬Έμ λν λ΅λ³μ μμ±ν΄ μ£ΌμΈμ. | |
| [μ§μμ¬ν] | |
| 1. κ²μλ μ₯μ μ€μμ μ§λ¬Έκ³Ό κ°μ₯ κ΄λ ¨μ±μ΄ λμ 2~3κ³³μ μΆμ²ν΄ μ£ΌμΈμ. | |
| 2. κ° μ₯μλ₯Ό μΆμ²ν λ, λ°λμ 'μ΄λ¦'κ³Ό 'μ£Όμ'λ₯Ό λͺ ννκ² ν¨κ» νμν΄μ£ΌμΈμ. | |
| 3. κ° μ₯μλ₯Ό μΆμ²νλ μ΄μ λ₯Ό 'AI μμ½'κ³Ό 'μ£Όμ νΉμ§ λ° νκΈ°'λ₯Ό κ·Όκ±°λ‘ κ΅¬μ²΄μ μΌλ‘ μ€λͺ ν΄ μ£ΌμΈμ. | |
| 4. 'processed_sentences'μ μλ μ€μ νκΈ°λ₯Ό μΈμ©νμ¬ λ΅λ³νλ©΄ μ λ’°λλ₯Ό λμΌ μ μμ΅λλ€. | |
| 5. μΉμ νκ³ μμ°μ€λ¬μ΄ λ§ν¬λ‘ λ΅λ³ν΄ μ£ΌμΈμ. | |
| --- μ₯μ μ 보 --- | |
| {context} | |
| --- μ¬μ©μμ μ§λ¬Έ --- | |
| {query} | |
| """ | |
| try: | |
| response = client.chat.completions.create( | |
| model=LLM_MODEL_NAME, | |
| messages=[ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": user_prompt} | |
| ], | |
| temperature=0.7, | |
| ) | |
| return response.choices[0].message.content | |
| except Exception as e: | |
| return f"LLM λ΅λ³ μμ± μ€ μ€λ₯κ° λ°μνμ΅λλ€: {e}" | |
# --- Conversation entry point ---
def get_rag_recommendation(search_query, region_keywords):
    """Run the full RAG pipeline and return the final recommendation text.

    Args:
        search_query: The user's natural-language search query.
        region_keywords: Iterable of region substrings used to filter
            candidate addresses; falsy/empty skips filtering.

    Returns:
        The LLM-generated recommendation, or a Korean status message when
        resources are unavailable or no matching place is found.
    """
    # Fix: check `is None` explicitly. The old `not all([...])` used
    # truthiness, so a successfully loaded but empty meta_map (falsy dict)
    # would wrongly report the whole system as not ready.
    if embedding_model is None or faiss_index is None or meta_map is None:
        return "RAG μμ€ν μ΄ μ€λΉλμ§ μμ μΆμ²μ μμ±ν μ μμ΅λλ€."
    # 1. Semantic retrieval over the FAISS index.
    print("\n[RAG] μλ―Έμ μΌλ‘ μ μ¬ν μ₯μλ₯Ό κ²μν©λλ€...")
    top_places = _retrieve_places(search_query, k=100)
    if not top_places:
        return "κ΄λ ¨λ μ₯μλ₯Ό μ°Ύμ§ λͺ»νμ΅λλ€."
    # 2. Region filtering by address substring (keep at most 10 matches).
    if region_keywords:
        print(f"[RAG] μ£Όμ νν°λ§ (ν€μλ: {region_keywords})...")
        filtered_places = []
        for place in top_places:
            address = place.get('address', '')
            if any(keyword in address for keyword in region_keywords):
                filtered_places.append(place)
                if len(filtered_places) >= 10:
                    break
        print(f"[RAG] νν°λ§ ν λ¨μ μ₯μ: {[p.get('name') for p in filtered_places]}")
    else:
        print("[RAG] μ§μ ν€μλκ° μμ΄ νν°λ§μ 건λλλλ€.")
        filtered_places = top_places
    if not filtered_places:
        return "μμ²νμ μ§μμ λ§λ μ₯μλ₯Ό μ°Ύμ§ λͺ»νμ΅λλ€."
    # 3. Ground the LLM answer in the filtered candidates.
    print("[RAG] νν°λ§λ μ 보λ₯Ό λ°νμΌλ‘ μ΅μ’ λ΅λ³μ μμ±ν©λλ€...")
    final_answer = _generate_answer_with_llm(search_query, filtered_places)
    return final_answer