ChatInsight / app.py
Jake-seong's picture
Update app.py
facba8d verified
raw
history blame
12.1 kB
import gradio as gr
import psycopg2
from openai import OpenAI
import json
import os
from typing import List, Dict
from pgvector.psycopg2 import register_vector
import numpy as np
# Weight and threshold settings.
# Each weight scales the similarity of one embedding column (full / topic /
# customer / agent) before the four values are summed into the combined score
# in the search queries below; the four weights add up to 1.0.
DEFAULT_FULL_WEIGHT = 0.2
DEFAULT_TOPIC_WEIGHT = 0.5
DEFAULT_CUSTOMER_WEIGHT = 0.2
DEFAULT_AGENT_WEIGHT = 0.1
# Results whose combined score is below this value are dropped by the searches.
DEFAULT_SIMILARITY_THRESHOLD = 0.5
# Database connection settings
def get_db_conn():
    """
    Open and return a new psycopg2 connection to the vector database.

    Host, database name, user and password are read from the environment
    variables VECTOR_HOST, VECTOR_DBNAME, VECTOR_USER and VECTOR_SECRET;
    the port is fixed at 5432.

    Raises:
        KeyError: if one of the required environment variables is not set.
    """
    env = os.environ
    connection_params = {
        "host": env["VECTOR_HOST"],
        "port": 5432,
        "dbname": env["VECTOR_DBNAME"],
        "user": env["VECTOR_USER"],
        "password": env["VECTOR_SECRET"],
    }
    return psycopg2.connect(**connection_params)
# Initialise the OpenAI client: a single module-level instance, constructed
# with no explicit arguments, that get_embedding() uses for embedding requests.
client = OpenAI()
def get_embedding(text: str) -> List[float]:
    """
    Convert text into an embedding vector with OpenAI's
    text-embedding-ada-002 model.

    The values are explicitly passed through float32 so that they are
    compatible with Java's float[] (float32).

    Args:
        text (str): Text to embed.

    Returns:
        List[float]: Embedding vector (float32-rounded values as a plain list).

    Raises:
        Exception: any error from the API call or the conversion is printed
            and then re-raised unchanged.
    """
    try:
        api_response = client.embeddings.create(
            model="text-embedding-ada-002",
            input=text,
            encoding_format="float",
        )
        raw_embedding = api_response.data[0].embedding
        # Explicit float32 round-trip for compatibility with Java's float[].
        as_float32 = np.array(raw_embedding, dtype=np.float32)
        return as_float32.tolist()
    except Exception as e:
        print(f"μž„λ² λ”© 생성 쀑 였λ₯˜ λ°œμƒ: {str(e)}")
        raise
def format_vector_for_pg(vector: List[float]) -> str:
    """
    Render an embedding vector in PostgreSQL vector text format.

    The input (a list or a NumPy array of any dtype) is first coerced to a
    float32 NumPy array; its elements are then joined with commas and wrapped
    in square brackets, e.g. "[0.5,1.0]".
    """
    # asarray converts non-arrays and non-float32 arrays, and leaves an
    # existing float32 array untouched.
    as_float32 = np.asarray(vector, dtype=np.float32)
    components = ("{}".format(component) for component in as_float32)
    return "[" + ",".join(components) + "]"
def get_text_value(node: Dict, field_name: str) -> str:
    """
    Safely extract a value from a dictionary.

    Mirrors the Java getTextValue() helper: returns the value stored under
    ``field_name``, or None when ``node`` is empty/None, the key is absent,
    or the stored value is None.
    """
    if not node:
        return None
    if field_name not in node:
        return None
    # A stored None falls through naturally as the "missing" result.
    return node[field_name]
def search_similar_chat(query: str, max_results: int = 100) -> List[Dict]:
    """
    Search the chat data for content similar to ``query``.

    The query is embedded once and compared with the pgvector ``<=>``
    operator against the full / topic / customer / agent embedding columns
    of ``vector_store_multi_embeddings``. Each per-column similarity
    (``1 - distance``, or 0 when the column is NULL) is multiplied by its
    weight and the four values are summed into a combined score.

    All values are sent to the database as bound parameters rather than
    being formatted into the SQL text.

    Args:
        query (str): Query text to search for.
        max_results (int): Maximum number of rows fetched from the database
            (defaults to 100 when None). It may arrive as a float, e.g. from
            a numeric UI input, and is coerced to int.

    Returns:
        List[Dict]: Matches ordered by descending combined similarity. Each
        item has ``id``, ``similarityScore``, ``content``, ``chatId`` and
        ``topic``, plus ``startTime`` / ``endTime`` when present in the row
        metadata. Rows scoring below DEFAULT_SIMILARITY_THRESHOLD are dropped
        after the LIMIT is applied, so fewer than ``max_results`` items may be
        returned. On any error the error is printed and an empty list is
        returned.
    """
    threshold = DEFAULT_SIMILARITY_THRESHOLD

    sql = """
        WITH embeddings AS (
            SELECT
                id,
                metadata,
                content,
                CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(full_w)s AS full_sim,
                CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(topic_w)s AS topic_sim,
                CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(customer_w)s AS customer_sim,
                CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(agent_w)s AS agent_sim
            FROM vector_store_multi_embeddings
            WHERE full_embedding IS NOT NULL
               OR topic_embedding IS NOT NULL
               OR customer_embedding IS NOT NULL
               OR agent_embedding IS NOT NULL
        )
        SELECT
            id,
            metadata,
            content,
            (full_sim + topic_sim + customer_sim + agent_sim) AS combined_similarity
        FROM embeddings
        ORDER BY combined_similarity DESC
        LIMIT %(limit)s
    """

    # Sentinel so the finally-clause knows whether a connection was opened.
    conn = None
    try:
        # Coerced inside the try so a bad value yields [] like any other error.
        limit = int(max_results) if max_results is not None else 100

        # Embed the query and render it in PostgreSQL vector text format.
        query_vector = format_vector_for_pg(get_embedding(query))

        params = {
            "query_vector": query_vector,
            "full_w": DEFAULT_FULL_WEIGHT,
            "topic_w": DEFAULT_TOPIC_WEIGHT,
            "customer_w": DEFAULT_CUSTOMER_WEIGHT,
            "agent_w": DEFAULT_AGENT_WEIGHT,
            "limit": limit,
        }

        conn = get_db_conn()
        register_vector(conn)

        with conn.cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()

        results = []
        for id_val, metadata_json, content, raw_score in rows:
            similarity_score = float(raw_score)
            # Metadata may come back as a JSON string or an already-decoded
            # object; a row whose metadata cannot be processed is skipped.
            try:
                metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
                result = {
                    "id": id_val,
                    "similarityScore": similarity_score,
                    "content": content,
                    "chatId": get_text_value(metadata, "chatId"),
                    "topic": get_text_value(metadata, "topic"),
                }
                # Time fields are passed through without any conversion.
                for time_key in ("startTime", "endTime"):
                    if time_key in metadata and metadata[time_key] is not None:
                        result[time_key] = metadata[time_key]
                results.append(result)
            except Exception as e:
                print(f"메타데이터 νŒŒμ‹± 였λ₯˜: {e}")
                continue

        # Threshold filtering.
        return [r for r in results if r["similarityScore"] >= threshold]
    except Exception as e:
        # Deliberate best-effort: report the failure and return no results.
        print(f"닀쀑 μž„λ² λ”© 검색 쀑 였λ₯˜ λ°œμƒ: {str(e)}")
        return []
    finally:
        if conn is not None:
            conn.close()
def search_similar_chat_by_date(
    query: str,
    start_date: str = None,
    end_date: str = None,
    max_results: int = 100
) -> List[Dict]:
    """
    Search the chat data for content similar to ``query`` within a date range.

    Scoring is the weighted sum of ``1 - (embedding <=> query_vector)`` over
    the full / topic / customer / agent embedding columns of
    ``vector_store_multi_embeddings`` (0 for a NULL column).

    The optional date bounds are compared, as ISO-formatted strings, against
    ``metadata->>'startTime'``. The "at least one embedding is present"
    condition is parenthesised so that the date bounds apply to every row;
    without the parentheses, AND binds tighter than OR and the bounds would
    only constrain the last OR branch. All caller-supplied values (dates,
    limit, vector, weights) are sent as bound parameters, never formatted
    into the SQL text.

    Args:
        query (str): Query text to search for.
        start_date (str): Inclusive start date (YYYY-MM-DD); ignored when
            None or blank.
        end_date (str): Inclusive end date (YYYY-MM-DD); ignored when None
            or blank.
        max_results (int): Maximum number of rows fetched from the database
            (defaults to 100 when None). It may arrive as a float, e.g. from
            a numeric UI input, and is coerced to int.

    Returns:
        List[Dict]: Matches ordered by descending combined similarity. Each
        item has ``id``, ``similarityScore``, ``content``, ``chatId`` and
        ``topic``, plus ``startTime`` / ``endTime`` when present in the row
        metadata. Rows scoring below DEFAULT_SIMILARITY_THRESHOLD are dropped
        after the LIMIT is applied. On any error the error is printed and an
        empty list is returned.
    """
    threshold = DEFAULT_SIMILARITY_THRESHOLD

    sql_head = """
        WITH embeddings AS (
            SELECT
                id,
                metadata,
                content,
                CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(full_w)s AS full_sim,
                CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(topic_w)s AS topic_sim,
                CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(customer_w)s AS customer_sim,
                CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(agent_w)s AS agent_sim
            FROM vector_store_multi_embeddings
            WHERE (full_embedding IS NOT NULL
                OR topic_embedding IS NOT NULL
                OR customer_embedding IS NOT NULL
                OR agent_embedding IS NOT NULL)
    """
    sql_tail = """
        )
        SELECT
            id,
            metadata,
            content,
            (full_sim + topic_sim + customer_sim + agent_sim) AS combined_similarity
        FROM embeddings
        ORDER BY combined_similarity DESC
        LIMIT %(limit)s
    """

    # Sentinel so the finally-clause knows whether a connection was opened.
    conn = None
    try:
        # Coerced inside the try so a bad value yields [] like any other error.
        limit = int(max_results) if max_results is not None else 100

        # Embed the query and render it in PostgreSQL vector text format.
        query_vector = format_vector_for_pg(get_embedding(query))

        params = {
            "query_vector": query_vector,
            "full_w": DEFAULT_FULL_WEIGHT,
            "topic_w": DEFAULT_TOPIC_WEIGHT,
            "customer_w": DEFAULT_CUSTOMER_WEIGHT,
            "agent_w": DEFAULT_AGENT_WEIGHT,
            "limit": limit,
        }

        # Only fixed SQL fragments are concatenated; the date values
        # themselves travel as bound parameters.
        date_clauses = []
        if start_date and start_date.strip():
            # Extend the date to the first second of the day (ISO format).
            params["start_ts"] = start_date.strip() + "T00:00:00"
            date_clauses.append(" AND metadata->>'startTime' >= %(start_ts)s")
        if end_date and end_date.strip():
            # Extend the date to the last second of the day (ISO format).
            params["end_ts"] = end_date.strip() + "T23:59:59"
            date_clauses.append(" AND metadata->>'startTime' <= %(end_ts)s")

        sql = sql_head + "".join(date_clauses) + sql_tail

        conn = get_db_conn()
        register_vector(conn)

        with conn.cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()

        results = []
        for id_val, metadata_json, content, raw_score in rows:
            similarity_score = float(raw_score)
            # Metadata may come back as a JSON string or an already-decoded
            # object; a row whose metadata cannot be processed is skipped.
            try:
                metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
                result = {
                    "id": id_val,
                    "similarityScore": similarity_score,
                    "content": content,
                    "chatId": get_text_value(metadata, "chatId"),
                    "topic": get_text_value(metadata, "topic"),
                }
                # Time fields are passed through without conversion (the
                # original note says they are already stored as KST).
                for time_key in ("startTime", "endTime"):
                    if time_key in metadata and metadata[time_key] is not None:
                        result[time_key] = metadata[time_key]
                results.append(result)
            except Exception as e:
                print(f"메타데이터 νŒŒμ‹± 였λ₯˜: {e}")
                continue

        # Threshold filtering.
        return [r for r in results if r["similarityScore"] >= threshold]
    except Exception as e:
        # Deliberate best-effort: report the failure and return no results.
        print(f"닀쀑 μž„λ² λ”© λ‚ μ§œ 검색 쀑 였λ₯˜ λ°œμƒ: {str(e)}")
        return []
    finally:
        if conn is not None:
            conn.close()
# Gradio web interface setup
with gr.Blocks() as demo:
    gr.Markdown("# Chat Analysis Search")

    # Plain similarity search: query text + maximum result count -> JSON.
    gr.Interface(
        fn=search_similar_chat,
        inputs=["text", "number"],
        outputs="json",
        api_name="search_similar_chat",
    )

    # Date-bounded search: query text, start date, end date, maximum count -> JSON.
    gr.Interface(
        fn=search_similar_chat_by_date,
        inputs=["text", "text", "text", "number"],
        outputs="json",
        api_name="search_similar_chat_by_date",
    )

if __name__ == "__main__":
    # Launch only when run as a script; mcp_server=True is passed to launch().
    demo.launch(mcp_server=True)