Spaces:
Paused
Paused
| import gradio as gr | |
| import psycopg2 | |
| from openai import OpenAI | |
| import json | |
| import os | |
| from typing import List, Dict, Tuple, Any | |
| from pgvector.psycopg2 import register_vector | |
| import numpy as np | |
| from datetime import datetime | |
| import re | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| # DB 연결 설정 | |
| def get_db_conn(): | |
| return psycopg2.connect( | |
| host=os.environ["VECTOR_HOST"], | |
| port=5432, | |
| dbname=os.environ["VECTOR_DBNAME"], | |
| user=os.environ["VECTOR_USER"], | |
| password=os.environ["VECTOR_SECRET"] | |
| ) | |
| client = OpenAI() | |
| def get_embedding(text: str) -> List[float]: | |
| """텍스트를 임베딩 벡터로 변환합니다.""" | |
| response = client.embeddings.create( | |
| input=text, | |
| model="text-embedding-3-small" | |
| ) | |
| return response.data[0].embedding | |
| def expand_query(query: str) -> str: | |
| """ | |
| 사용자 쿼리를 확장하여 검색 품질을 개선합니다. | |
| """ | |
| # GPT를 활용한 쿼리 확장 | |
| try: | |
| response = client.chat.completions.create( | |
| model="gpt-3.5-turbo", | |
| messages=[ | |
| {"role": "system", "content": "당신은 검색 쿼리 확장 전문가입니다. 사용자의 쿼리를 분석하고, 이와 관련된 키워드와 질문 형태로 확장하세요."}, | |
| {"role": "user", "content": f"다음 검색어를 확장해주세요: '{query}'"} | |
| ], | |
| temperature=0.3, | |
| max_tokens=150 | |
| ) | |
| expanded = query + " " + response.choices[0].message.content | |
| return expanded | |
| except: | |
| # 오류 발생 시 원본 쿼리 반환 | |
| return query | |
| def extract_keywords(text: str) -> List[str]: | |
| """ | |
| 텍스트에서 중요 키워드를 추출합니다. | |
| """ | |
| # 단순한 키워드 추출 (고급 NLP 라이브러리로 대체 가능) | |
| # 불용어 제거 및 정규표현식으로 키워드 추출 | |
| stop_words = {'있는', '하는', '그리고', '입니다', '그것은', '있습니다', '합니다', '그런', '이런', '저런', '그냥'} | |
| words = re.findall(r'\w+', text.lower()) | |
| keywords = [w for w in words if len(w) > 1 and w not in stop_words] | |
| return list(set(keywords)) | |
| def perform_hybrid_search( | |
| query: str, | |
| vector_results: List[Dict], | |
| keyword_weight: float = 0.3, | |
| similarity_threshold: float = 0.4 | |
| ) -> List[Dict]: | |
| """ | |
| 벡터 검색과 키워드 검색을 결합한 하이브리드 검색을 수행합니다. | |
| """ | |
| # 임계값 미만의 결과 필터링 | |
| filtered_results = [r for r in vector_results if r["similarity"] >= similarity_threshold] | |
| if not filtered_results: | |
| # 결과가 없으면 임계값을 낮춰서 재시도 | |
| filtered_results = [r for r in vector_results if r["similarity"] >= similarity_threshold * 0.7] | |
| if not filtered_results: | |
| return vector_results[:5] # 여전히 없으면 상위 5개 반환 | |
| # 키워드 검색 가중치 적용 | |
| keywords = extract_keywords(query) | |
| for result in filtered_results: | |
| content = result.get("content", "") | |
| keyword_matches = sum(1 for kw in keywords if kw.lower() in content.lower()) | |
| keyword_score = keyword_matches / max(len(keywords), 1) | |
| # 최종 점수 계산 (벡터 유사도 + 키워드 가중치) | |
| result["original_similarity"] = result["similarity"] | |
| result["keyword_score"] = keyword_score | |
| result["similarity"] = (1 - keyword_weight) * result["similarity"] + keyword_weight * keyword_score | |
| # 최종 점수로 재정렬 | |
| return sorted(filtered_results, key=lambda x: x["similarity"], reverse=True) | |
| def preprocess_query(query: str) -> str: | |
| """ | |
| 검색 쿼리를 전처리하여 검색 품질을 개선합니다. | |
| """ | |
| # 검색에 맞게 프롬프트 재구성 | |
| return f"다음 질문이나 주제와 관련된 대화를 찾아주세요: {query}" | |
| def search_similar_chats(query: str, maxResults: int = 200) -> List[Dict]: | |
| """ | |
| 유사한 채팅 문서를 검색합니다. | |
| Args: | |
| query (str): 검색할 쿼리 텍스트 | |
| maxResults (int): 반환할 최대 결과 수 | |
| Returns: | |
| List[Dict]: 검색 결과 목록 | |
| """ | |
| # 쿼리 전처리 및 확장 | |
| processed_query = preprocess_query(query) | |
| try: | |
| expanded_query = expand_query(processed_query) | |
| except: | |
| expanded_query = processed_query | |
| embedding = np.array(get_embedding(expanded_query)) | |
| conn = get_db_conn() | |
| register_vector(conn) | |
| try: | |
| with conn.cursor() as cur: | |
| # 코사인 유사도 계산 | |
| cur.execute(""" | |
| SELECT id, metadata, content, | |
| 1 - (embedding <=> %s) AS similarity | |
| FROM vector_store | |
| ORDER BY similarity DESC | |
| LIMIT %s | |
| """, (embedding, maxResults)) | |
| rows = cur.fetchall() | |
| results = [{ | |
| "id": row[0], | |
| "metadata": row[1], | |
| "content": row[2], | |
| "similarity": float(row[3]) | |
| } for row in rows] | |
| # 하이브리드 검색 적용 | |
| results = perform_hybrid_search( | |
| query, | |
| results, | |
| keyword_weight=0.3, | |
| similarity_threshold=0.4 | |
| ) | |
| return results | |
| except Exception as e: | |
| raise RuntimeError(f"DB 검색 오류: {str(e)}") | |
| finally: | |
| conn.close() | |
| def search_similar_chats_by_date( | |
| query: str, | |
| startDate: str = None, | |
| endDate: str = None, | |
| maxResults: int = 200 | |
| ) -> List[Dict]: | |
| """ | |
| 지정된 날짜 범위에 해당하는 유사한 채팅 문서를 검색합니다. | |
| Args: | |
| query (str): 검색 쿼리 | |
| startDate (str): 검색 시작 날짜 (YYYY-MM-DD) | |
| endDate (str): 검색 종료 날짜 (YYYY-MM-DD) | |
| maxResults (int): 반환할 최대 결과 수 | |
| Returns: | |
| List[Dict]: 검색 결과 목록 | |
| """ | |
| try: | |
| start_dt = datetime.strptime(startDate, "%Y-%m-%d") if startDate else None | |
| end_dt = datetime.strptime(endDate, "%Y-%m-%d") if endDate else None | |
| except ValueError as e: | |
| raise ValueError(f"날짜 형식 오류: {e}") | |
| # 쿼리 전처리 및 확장 | |
| processed_query = preprocess_query(query) | |
| try: | |
| expanded_query = expand_query(processed_query) | |
| except: | |
| expanded_query = processed_query | |
| embedding = np.array(get_embedding(expanded_query)) | |
| conn = get_db_conn() | |
| register_vector(conn) | |
| try: | |
| with conn.cursor() as cur: | |
| base_query = """ | |
| SELECT id, metadata, content, | |
| 1 - (embedding <=> %s) AS similarity | |
| FROM vector_store | |
| WHERE 1=1 | |
| """ | |
| params = [embedding] | |
| # 동적 쿼리 구성 | |
| if startDate: | |
| base_query += " AND (metadata->>'startTime')::date >= %s" | |
| params.append(startDate) | |
| if endDate: | |
| base_query += " AND (metadata->>'startTime')::date <= %s" | |
| params.append(endDate) | |
| base_query += " ORDER BY similarity DESC LIMIT %s" | |
| params.append(maxResults) | |
| cur.execute(base_query, tuple(params)) | |
| rows = cur.fetchall() | |
| results = [{ | |
| "id": row[0], | |
| "metadata": row[1], | |
| "content": row[2], | |
| "similarity": float(row[3]) | |
| } for row in rows] | |
| # 하이브리드 검색 적용 | |
| results = perform_hybrid_search( | |
| query, | |
| results, | |
| keyword_weight=0.3, | |
| similarity_threshold=0.4 | |
| ) | |
| # 메타데이터 기반 가중치 적용 | |
| keywords = extract_keywords(query) | |
| for result in results: | |
| metadata = result.get("metadata", {}) | |
| if not metadata or isinstance(metadata, str): | |
| continue | |
| # 주제(topic) 필드에 키워드가 있는지 확인 | |
| topic = metadata.get("topic", "") | |
| topic_matches = sum(1 for kw in keywords if kw.lower() in topic.lower()) | |
| # 주제 일치 가중치 적용 | |
| if topic_matches > 0: | |
| topic_boost = 0.1 * min(topic_matches, 3) # 최대 0.3 가중치 | |
| result["similarity"] += topic_boost | |
| result["topic_boost"] = topic_boost | |
| # 결과 재정렬 | |
| results = sorted(results, key=lambda x: x["similarity"], reverse=True) | |
| return results | |
| except Exception as e: | |
| raise RuntimeError(f"DB 검색 오류: {str(e)}") | |
| finally: | |
| conn.close() | |
| # Gradio Blocks에 함수 등록 | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Chat Analysis Search") | |
| gr.Interface(fn=search_similar_chats, inputs=["text", "number"], outputs="json", api_name="search_similar_chats") | |
| gr.Interface(fn=search_similar_chats_by_date, inputs=["text", "text", "text", "number"], outputs="json", api_name="search_similar_chats_by_date") | |
| if __name__ == "__main__": | |
| demo.launch(mcp_server=True) | |