Spaces:
Build error
Build error
| # ./backend/app/rag_core.py | |
| import os | |
| import httpx | |
| from fastapi import HTTPException | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import numpy as np | |
| from typing import List, Dict, Tuple | |
| # OLLAMA_API_BASE_URL 환경 변수 설정 | |
| OLLAMA_API_BASE_URL = os.getenv("OLLAMA_API_BASE_URL", "http://127.0.0.1:11434") | |
| # 전역 변수로 모델 로드 (앱 시작 시 한 번만 로드되도록) | |
| try: | |
| model = SentenceTransformer('jhgan/ko-sroberta-multitask', device='cpu') | |
| print("INFO: 임베딩 모델 'jhgan/ko-sroberta-multitask' 로드 완료.") | |
| except Exception as e: | |
| print(f"ERROR: 임베딩 모델 'jhgan/ko-sroberta-multitask' 로드 실패: {e}. 다국어 모델로 시도합니다.") | |
| try: | |
| model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device='cpu') | |
| print("INFO: 임베딩 모델 'sentence-transformers/paraphrase-multilingual-L12-v2' 로드 완료.") | |
| except Exception as e: | |
| print(f"ERROR: 대체 임베딩 모델 로드 실패: {e}. RAG 기능을 사용할 수 없습니다.") | |
| raise | |
| async def generate_answer_with_ollama(model_name: str, prompt: str) -> str: | |
| """ | |
| Ollama 서버에 질의하여 답변을 생성합니다. | |
| """ | |
| url = f"{OLLAMA_API_BASE_URL}/api/generate" | |
| headers = {"Content-Type": "application/json"} | |
| data = { | |
| "model": model_name, | |
| "prompt": prompt, | |
| "stream": False | |
| } | |
| print(f"INFO: Ollama API 호출 시작. 모델: {model_name}") | |
| print(f"INFO: 프롬프트 미리보기: {prompt[:200]}...") | |
| try: | |
| async with httpx.AsyncClient(timeout=600.0) as client: | |
| response = await client.post(url, headers=headers, json=data) | |
| response.raise_for_status() | |
| response_data = response.json() | |
| full_response = response_data.get("response", "").strip() | |
| return full_response | |
| except httpx.HTTPStatusError as e: | |
| print(f"ERROR: Ollama API 호출 실패: {e}") | |
| raise HTTPException(status_code=500, detail="Ollama API 호출 실패") | |
| except httpx.RequestError as e: | |
| print(f"ERROR: 네트워크 오류: {e}") | |
| raise HTTPException(status_code=500, detail="네트워크 오류가 발생했습니다. 잠시 후 다시 시도해주세요.") | |
| except Exception as e: | |
| print(f"ERROR: 알 수 없는 오류: {e}") | |
| raise HTTPException(status_code=500, detail="알 수 없는 오류가 발생했습니다. 잠시 후 다시 시도해주세요.") | |
| async def perform_retrieval(chunks_with_timestamps: List[Dict], query: str, top_k: int = 5) -> List[Dict]: | |
| """ | |
| 제공된 텍스트 청크에서 쿼리와 가장 유사한 부분을 검색합니다. (Retrieval-only) | |
| """ | |
| if not chunks_with_timestamps: | |
| print("WARNING: RAG 검색을 위한 텍스트 청크가 없습니다.") | |
| return [] | |
| texts = [chunk["text"] for chunk in chunks_with_timestamps] | |
| print(f"INFO: 총 {len(texts)}개의 텍스트 청크 임베딩 시작.") | |
| try: | |
| chunk_embeddings = model.encode(texts, convert_to_numpy=True) | |
| except Exception as e: | |
| print(f"ERROR: 텍스트 청크 임베딩 중 오류 발생: {e}") | |
| return [] | |
| dimension = chunk_embeddings.shape[1] | |
| index = faiss.IndexFlatIP(dimension) | |
| faiss.normalize_L2(chunk_embeddings) | |
| index.add(chunk_embeddings) | |
| query_embedding = model.encode([query], convert_to_numpy=True) | |
| faiss.normalize_L2(query_embedding) | |
| similarities, indices = index.search(query_embedding, top_k) | |
| retrieved_chunks = [] | |
| MIN_SIMILARITY_THRESHOLD = 0.35 # 임계값 | |
| for i in range(len(indices[0])): | |
| idx = indices[0][i] | |
| original_chunk = chunks_with_timestamps[idx] | |
| score = float(similarities[0][i]) | |
| if score > MIN_SIMILARITY_THRESHOLD: | |
| retrieved_chunks.append({ | |
| "text": original_chunk["text"], | |
| "timestamp": original_chunk["timestamp"], | |
| "score": score, | |
| "start_seconds": original_chunk["start_seconds"] | |
| }) | |
| else: | |
| print(f"DEBUG: 유사도 임계값({MIN_SIMILARITY_THRESHOLD:.4f}) 미만으로 제외된 청크 (유사도: {score:.4f}): {original_chunk['text'][:50]}...") | |
| retrieved_chunks.sort(key=lambda x: x['start_seconds']) | |
| print(f"DEBUG: 최종 검색된 청크 수: {len(retrieved_chunks)}") | |
| return retrieved_chunks | |
| async def perform_rag_and_generate(query: str, chunks_with_timestamps: List[Dict], ollama_model_name: str, top_k: int = 50) -> Dict: | |
| """ | |
| RAG의 전체 프로세스(검색, 프롬프트 구성, 생성)를 수행합니다. | |
| """ | |
| # 1. RAG 검색 수행 | |
| retrieved_chunks = await perform_retrieval( | |
| chunks_with_timestamps=chunks_with_timestamps, | |
| query=query, | |
| top_k=top_k | |
| ) | |
| if not retrieved_chunks: | |
| return { | |
| "status": "error", | |
| "message": "검색 결과가 없습니다.", | |
| "retrieved_chunks": [], | |
| "generated_answer": "관련 정보를 찾지 못해 답변을 생성할 수 없습니다." | |
| } | |
| # 2. 검색 결과를 프롬프트에 추가 | |
| context = "\n\n".join([chunk["text"] for chunk in retrieved_chunks]) | |
| prompt = f"""당신은 유튜브 영상 내용을 완벽하게 이해하고 사용자의 질문에 답변하는 AI 어시스턴트입니다. | |
| 아래는 분석한 유튜브 영상의 자막 내용입니다. 이 정보를 바탕으로 사용자의 질문에 대해 상세하고 친절하게 답변하세요. | |
| 답변은 반드시 영상 내용에 근거해야 하며, 내용과 관련 없는 질문에는 '영상 내용과 관련이 없어 답변할 수 없습니다'라고 솔직하게 말해야 합니다. | |
| --- 유튜브 영상 자막 내용 --- | |
| {context} | |
| -------------------------- | |
| 사용자 질문: {query} | |
| 답변:""" | |
| # 3. Ollama 모델에 질의하여 답변 생성 | |
| generated_answer = await generate_answer_with_ollama( | |
| model_name=ollama_model_name, | |
| prompt=prompt | |
| ) | |
| return { | |
| "status": "success", | |
| "message": "성공적으로 영상을 처리하고 RAG 검색을 수행했습니다.", | |
| "retrieved_chunks": retrieved_chunks, | |
| "generated_answer": generated_answer | |
| } | |