Jake-seong commited on
Commit
8866252
·
verified ·
1 Parent(s): 6bc792a

조회 단계에서 하이브리드 검색 도입

Browse files
Files changed (1) hide show
  1. app.py +138 -6
app.py CHANGED
@@ -3,10 +3,13 @@ import psycopg2
3
  from openai import OpenAI
4
  import json
5
  import os
6
- from typing import List, Dict
7
  from pgvector.psycopg2 import register_vector
8
  import numpy as np
9
  from datetime import datetime
 
 
 
10
 
11
  # DB 연결 설정
12
  def get_db_conn():
@@ -28,6 +31,80 @@ def get_embedding(text: str) -> List[float]:
28
  )
29
  return response.data[0].embedding
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def search_similar_chats(query: str, maxResults: int = 200) -> List[Dict]:
32
  """
33
  유사한 채팅 문서를 검색합니다.
@@ -37,13 +114,20 @@ def search_similar_chats(query: str, maxResults: int = 200) -> List[Dict]:
37
  Returns:
38
  List[Dict]: 검색 결과 목록
39
  """
40
- embedding = np.array(get_embedding(query))
 
 
 
 
 
 
 
41
  conn = get_db_conn()
42
  register_vector(conn)
43
 
44
  try:
45
  with conn.cursor() as cur:
46
- # 코사인 유사도 연산자 변경 (<=> 사용)
47
  cur.execute("""
48
  SELECT id, metadata, content,
49
  1 - (embedding <=> %s) AS similarity
@@ -53,12 +137,23 @@ def search_similar_chats(query: str, maxResults: int = 200) -> List[Dict]:
53
  """, (embedding, maxResults))
54
 
55
  rows = cur.fetchall()
56
- return [{
 
57
  "id": row[0],
58
  "metadata": row[1],
59
  "content": row[2],
60
  "similarity": float(row[3])
61
  } for row in rows]
 
 
 
 
 
 
 
 
 
 
62
  except Exception as e:
63
  raise RuntimeError(f"DB 검색 오류: {str(e)}")
64
  finally:
@@ -87,7 +182,14 @@ def search_similar_chats_by_date(
87
  except ValueError as e:
88
  raise ValueError(f"날짜 형식 오류: {e}")
89
 
90
- embedding = np.array(get_embedding(query))
 
 
 
 
 
 
 
91
  conn = get_db_conn()
92
  register_vector(conn)
93
 
@@ -115,12 +217,42 @@ def search_similar_chats_by_date(
115
  cur.execute(base_query, tuple(params))
116
  rows = cur.fetchall()
117
 
118
- return [{
119
  "id": row[0],
120
  "metadata": row[1],
121
  "content": row[2],
122
  "similarity": float(row[3])
123
  } for row in rows]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  except Exception as e:
125
  raise RuntimeError(f"DB 검색 오류: {str(e)}")
126
  finally:
 
3
  from openai import OpenAI
4
  import json
5
  import os
6
+ from typing import List, Dict, Tuple, Any
7
  from pgvector.psycopg2 import register_vector
8
  import numpy as np
9
  from datetime import datetime
10
+ import re
11
+ from sklearn.feature_extraction.text import TfidfVectorizer
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
 
14
  # DB 연결 설정
15
  def get_db_conn():
 
31
  )
32
  return response.data[0].embedding
33
 
34
+ def expand_query(query: str) -> str:
35
+ """
36
+ 사용자 쿼리를 확장하여 검색 품질을 개선합니다.
37
+ """
38
+ # GPT를 활용한 쿼리 확장
39
+ try:
40
+ response = client.chat.completions.create(
41
+ model="gpt-3.5-turbo",
42
+ messages=[
43
+ {"role": "system", "content": "당신은 검색 쿼리 확장 전문가입니다. 사용자의 쿼리를 분석하고, 이와 관련된 키워드와 질문 형태로 확장하세요."},
44
+ {"role": "user", "content": f"다음 검색어를 확장해주세요: '{query}'"}
45
+ ],
46
+ temperature=0.3,
47
+ max_tokens=150
48
+ )
49
+ expanded = query + " " + response.choices[0].message.content
50
+ return expanded
51
+ except:
52
+ # 오류 발생 시 원본 쿼리 반환
53
+ return query
54
+
55
+ def extract_keywords(text: str) -> List[str]:
56
+ """
57
+ 텍스트에서 중요 키워드를 추출합니다.
58
+ """
59
+ # 단순한 키워드 추출 (고급 NLP 라이브러리로 대체 가능)
60
+ # 불용어 제거 및 정규표현식으로 키워드 추출
61
+ stop_words = {'있는', '하는', '그리고', '입니다', '그것은', '있습니다', '합니다', '그런', '이런', '저런', '그냥'}
62
+ words = re.findall(r'\w+', text.lower())
63
+ keywords = [w for w in words if len(w) > 1 and w not in stop_words]
64
+ return list(set(keywords))
65
+
66
+ def perform_hybrid_search(
67
+ query: str,
68
+ vector_results: List[Dict],
69
+ keyword_weight: float = 0.3,
70
+ similarity_threshold: float = 0.4
71
+ ) -> List[Dict]:
72
+ """
73
+ 벡터 검색과 키워드 검색을 결합한 하이브리드 검색을 수행합니다.
74
+ """
75
+ # 임계값 미만의 결과 필터링
76
+ filtered_results = [r for r in vector_results if r["similarity"] >= similarity_threshold]
77
+
78
+ if not filtered_results:
79
+ # 결과가 없으면 임계값을 낮춰서 재시도
80
+ filtered_results = [r for r in vector_results if r["similarity"] >= similarity_threshold * 0.7]
81
+
82
+ if not filtered_results:
83
+ return vector_results[:5] # 여전히 없으면 상위 5개 반환
84
+
85
+ # 키워드 검색 가중치 적용
86
+ keywords = extract_keywords(query)
87
+
88
+ for result in filtered_results:
89
+ content = result.get("content", "")
90
+ keyword_matches = sum(1 for kw in keywords if kw.lower() in content.lower())
91
+ keyword_score = keyword_matches / max(len(keywords), 1)
92
+
93
+ # 최종 점수 계산 (벡터 유사도 + 키워드 가중치)
94
+ result["original_similarity"] = result["similarity"]
95
+ result["keyword_score"] = keyword_score
96
+ result["similarity"] = (1 - keyword_weight) * result["similarity"] + keyword_weight * keyword_score
97
+
98
+ # 최종 점수로 재정렬
99
+ return sorted(filtered_results, key=lambda x: x["similarity"], reverse=True)
100
+
101
+ def preprocess_query(query: str) -> str:
102
+ """
103
+ 검색 쿼리를 전처리하여 검색 품질을 개선합니다.
104
+ """
105
+ # 검색에 맞게 프롬프트 재구성
106
+ return f"다음 질문이나 주제와 관련된 대화를 찾아주세요: {query}"
107
+
108
  def search_similar_chats(query: str, maxResults: int = 200) -> List[Dict]:
109
  """
110
  유사한 채팅 문서를 검색합니다.
 
114
  Returns:
115
  List[Dict]: 검색 결과 목록
116
  """
117
+ # 쿼리 전처리 및 확장
118
+ processed_query = preprocess_query(query)
119
+ try:
120
+ expanded_query = expand_query(processed_query)
121
+ except:
122
+ expanded_query = processed_query
123
+
124
+ embedding = np.array(get_embedding(expanded_query))
125
  conn = get_db_conn()
126
  register_vector(conn)
127
 
128
  try:
129
  with conn.cursor() as cur:
130
+ # 코사인 유사도 계산
131
  cur.execute("""
132
  SELECT id, metadata, content,
133
  1 - (embedding <=> %s) AS similarity
 
137
  """, (embedding, maxResults))
138
 
139
  rows = cur.fetchall()
140
+
141
+ results = [{
142
  "id": row[0],
143
  "metadata": row[1],
144
  "content": row[2],
145
  "similarity": float(row[3])
146
  } for row in rows]
147
+
148
+ # 하��브리드 검색 적용
149
+ results = perform_hybrid_search(
150
+ query,
151
+ results,
152
+ keyword_weight=0.3,
153
+ similarity_threshold=0.4
154
+ )
155
+
156
+ return results
157
  except Exception as e:
158
  raise RuntimeError(f"DB 검색 오류: {str(e)}")
159
  finally:
 
182
  except ValueError as e:
183
  raise ValueError(f"날짜 형식 오류: {e}")
184
 
185
+ # 쿼리 전처리 및 확장
186
+ processed_query = preprocess_query(query)
187
+ try:
188
+ expanded_query = expand_query(processed_query)
189
+ except:
190
+ expanded_query = processed_query
191
+
192
+ embedding = np.array(get_embedding(expanded_query))
193
  conn = get_db_conn()
194
  register_vector(conn)
195
 
 
217
  cur.execute(base_query, tuple(params))
218
  rows = cur.fetchall()
219
 
220
+ results = [{
221
  "id": row[0],
222
  "metadata": row[1],
223
  "content": row[2],
224
  "similarity": float(row[3])
225
  } for row in rows]
226
+
227
+ # 하이브리드 검색 적용
228
+ results = perform_hybrid_search(
229
+ query,
230
+ results,
231
+ keyword_weight=0.3,
232
+ similarity_threshold=0.4
233
+ )
234
+
235
+ # 메타데이터 기반 가중치 적용
236
+ keywords = extract_keywords(query)
237
+ for result in results:
238
+ metadata = result.get("metadata", {})
239
+ if not metadata or isinstance(metadata, str):
240
+ continue
241
+
242
+ # 주제(topic) 필드에 키워드가 있는지 확인
243
+ topic = metadata.get("topic", "")
244
+ topic_matches = sum(1 for kw in keywords if kw.lower() in topic.lower())
245
+
246
+ # 주제 일치 가중치 적용
247
+ if topic_matches > 0:
248
+ topic_boost = 0.1 * min(topic_matches, 3) # 최대 0.3 가중치
249
+ result["similarity"] += topic_boost
250
+ result["topic_boost"] = topic_boost
251
+
252
+ # 결과 재정렬
253
+ results = sorted(results, key=lambda x: x["similarity"], reverse=True)
254
+
255
+ return results
256
  except Exception as e:
257
  raise RuntimeError(f"DB 검색 오류: {str(e)}")
258
  finally: