ChatInsight / app.py
Jake-seong's picture
Update app.py
70b2365 verified
raw
history blame
14.2 kB
import gradio as gr
import psycopg2
from openai import OpenAI
import json
import os
from typing import List, Dict, Tuple, Any
from pgvector.psycopg2 import register_vector
import numpy as np
from datetime import datetime
from sklearn.preprocessing import normalize
# Search tuning: per-embedding weights and the minimum score threshold.
# The four weights below scale the scores of the full / topic / customer /
# agent embedding columns before they are summed into one ranking score.
DEFAULT_FULL_WEIGHT: float = 0.2
DEFAULT_TOPIC_WEIGHT: float = 0.5
DEFAULT_CUSTOMER_WEIGHT: float = 0.2
DEFAULT_AGENT_WEIGHT: float = 0.1
# Results whose combined score is below this value are filtered out.
DEFAULT_SIMILARITY_THRESHOLD: int = 0
# Database connection settings.
def get_db_conn():
    """Open and return a new psycopg2 connection to the vector database.

    Host, database name, user and password are read from the ``VECTOR_HOST``,
    ``VECTOR_DBNAME``, ``VECTOR_USER`` and ``VECTOR_SECRET`` environment
    variables (a missing variable raises ``KeyError``); the port is fixed at
    5432. Every call opens a fresh connection that the caller must close.
    """
    env = os.environ
    settings = {
        "host": env["VECTOR_HOST"],
        "port": 5432,
        "dbname": env["VECTOR_DBNAME"],
        "user": env["VECTOR_USER"],
        "password": env["VECTOR_SECRET"],
    }
    return psycopg2.connect(**settings)
# Module-level OpenAI client, created once at import time and used by
# get_embedding(). NOTE(review): OpenAI() is constructed without arguments,
# so its configuration (e.g. the API key) presumably comes from the
# environment — confirm the deployment provides it.
client = OpenAI()
def get_embedding(text: str) -> List[float]:
    """Embed ``text`` with the ``text-embedding-3-small`` model.

    Sends a single-input embeddings request through the module-level
    ``client`` and returns the embedding of that input. API errors are not
    handled here; they propagate to the caller.
    """
    reply = client.embeddings.create(model="text-embedding-3-small", input=text)
    embedding_item = reply.data[0]
    return embedding_item.embedding
def get_text_value(node, field_name):
    """Return ``node[field_name]``, or None when the value cannot be read.

    None is returned when ``node`` is falsy (e.g. None or an empty dict) or
    does not contain ``field_name``; a stored None value is returned as None.
    """
    if not node or field_name not in node:
        return None
    return node[field_name]
def format_vector_for_pg(vector: List[float]) -> str:
    """Render ``vector`` as a pgvector text literal such as ``[0.1,0.2,0.3]``.

    Each component is converted with ``str()`` and the parts are joined with
    commas and no spaces.
    """
    body = ",".join(map(str, vector))
    return "[" + body + "]"
def search_similar_chat(query: str, max_results: int = 100) -> List[Dict]:
"""
๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ๋œ ์ฑ„ํŒ… ๋ฐ์ดํ„ฐ์—์„œ ์œ ์‚ฌํ•œ ์ฝ˜ํ…์ธ ๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค.
Args:
query (str): ๊ฒ€์ƒ‰ํ•  ์ฟผ๋ฆฌ ํ…์ŠคํŠธ
max_results (int): ๋ฐ˜ํ™˜ํ•  ์ตœ๋Œ€ ๊ฒฐ๊ณผ ์ˆ˜
Returns:
List[Dict]: ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ๋ชฉ๋ก
"""
limit = max_results if max_results is not None else 100
# ๊ฐ€์ค‘์น˜ ์„ค์ •
full_w = DEFAULT_FULL_WEIGHT
topic_w = DEFAULT_TOPIC_WEIGHT
customer_w = DEFAULT_CUSTOMER_WEIGHT
agent_w = DEFAULT_AGENT_WEIGHT
threshold = DEFAULT_SIMILARITY_THRESHOLD
print(f"๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ๊ฒ€์ƒ‰ ์‹œ์ž‘: ์ฟผ๋ฆฌ='{query}', ๊ฐ€์ค‘์น˜=(full={full_w}, topic={topic_w}, customer={customer_w}, agent={agent_w}), ์ตœ๋Œ€ ๊ฒฐ๊ณผ={limit}")
try:
# ์ฟผ๋ฆฌ ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ
raw_embedding = np.array(get_embedding(query))
# L2 ์ •๊ทœํ™” ์ ์šฉ
query_embedding = normalize(raw_embedding.reshape(1, -1), norm='l2')[0]
print(f"์ž„๋ฒ ๋”ฉ ์ •๊ทœํ™” ์ „/ํ›„ ์ฒซ 5๊ฐœ ์š”์†Œ: {raw_embedding[:5]} -> {query_embedding[:5]}")
# Java ๋ฐฉ์‹: ๋ฒกํ„ฐ๋ฅผ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
query_vector = format_vector_for_pg(query_embedding)
# DB ์—ฐ๊ฒฐ
conn = get_db_conn()
register_vector(conn)
# Java ๋ฐฉ์‹: ๋ฌธ์ž์—ด ํฌ๋งทํŒ… ์‚ฌ์šฉํ•œ SQL ์ฟผ๋ฆฌ
sql = f"""
WITH embeddings AS (
SELECT
id,
metadata,
content,
CASE WHEN full_embedding IS NOT NULL THEN (full_embedding <=> '{query_vector}'::vector) ELSE 0 END * {full_w} as full_sim,
CASE WHEN topic_embedding IS NOT NULL THEN (topic_embedding <=> '{query_vector}'::vector) ELSE 0 END * {topic_w} as topic_sim,
CASE WHEN customer_embedding IS NOT NULL THEN (customer_embedding <=> '{query_vector}'::vector) ELSE 0 END * {customer_w} as customer_sim,
CASE WHEN agent_embedding IS NOT NULL THEN (agent_embedding <=> '{query_vector}'::vector) ELSE 0 END * {agent_w} as agent_sim
FROM vector_store_multi_embeddings
WHERE full_embedding IS NOT NULL
OR topic_embedding IS NOT NULL
OR customer_embedding IS NOT NULL
OR agent_embedding IS NOT NULL
)
SELECT
id,
metadata,
content,
(full_sim + topic_sim + customer_sim + agent_sim) as combined_similarity
FROM embeddings
ORDER BY combined_similarity DESC
LIMIT {limit}
"""
with conn.cursor() as cur:
print(f"์ฟผ๋ฆฌ ์‹คํ–‰ - Java ๋ฐฉ์‹ ํฌ๋งทํŒ…, ๊ฐ€์ค‘์น˜ ์„ค์ •={full_w}, {topic_w}, {customer_w}, {agent_w}, ๊ฒฐ๊ณผ ์ œํ•œ={limit}")
# Java ๋ฐฉ์‹: ๋งค๊ฐœ๋ณ€์ˆ˜ ์—†์ด ์ง์ ‘ ์ฟผ๋ฆฌ ์‹คํ–‰
cur.execute(sql)
rows = cur.fetchall()
print(f"๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ: ์ด {len(rows)}๊ฐœ ๋ฐ์ดํ„ฐ ์กฐํšŒ๋จ")
if len(rows) > 0:
print(f"์ฒซ ๋ฒˆ์งธ ๊ฒฐ๊ณผ ID: {rows[0][0]}, ์œ ์‚ฌ๋„: {float(rows[0][3])}")
results = []
for row in rows:
id_val = row[0]
metadata_json = row[1]
content = row[2]
similarity_score = float(row[3])
# ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ํŒŒ์‹ฑ
try:
metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
result = {
"id": id_val,
"similarityScore": similarity_score,
"content": content,
"chatId": get_text_value(metadata, "chatId"),
"topic": get_text_value(metadata, "topic")
}
# ์‹œ๊ฐ„ ํ•„๋“œ ๋ณ€ํ™˜ ์—†์ด ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ
if "startTime" in metadata and metadata["startTime"] is not None:
result["startTime"] = metadata["startTime"]
if "endTime" in metadata and metadata["endTime"] is not None:
result["endTime"] = metadata["endTime"]
results.append(result)
except Exception as e:
print(f"๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ํŒŒ์‹ฑ ์˜ค๋ฅ˜: {e}")
print(f"๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•œ ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ: {metadata_json[:200]}...")
continue
# ์ž„๊ณ„๊ฐ’ ํ•„ํ„ฐ๋ง
filtered_results = [r for r in results if r["similarityScore"] >= threshold]
print(f"์ž„๊ณ„๊ฐ’({threshold}) ์ด์ƒ ๊ฒฐ๊ณผ: {len(filtered_results)}๊ฐœ / ์ „์ฒด {len(results)}๊ฐœ")
if len(filtered_results) > 0:
print(f"๊ฐ€์žฅ ๋†’์€ ์œ ์‚ฌ๋„ ์ ์ˆ˜: {filtered_results[0]['similarityScore']}")
print(f"์ƒ์œ„ ๊ฒฐ๊ณผ ์ฑ—ID: {filtered_results[0].get('chatId')}, ์ฃผ์ œ: {filtered_results[0].get('topic', '')[:50]}...")
return filtered_results
except Exception as e:
print(f"๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
return []
finally:
if 'conn' in locals():
conn.close()
def search_similar_chat_by_date(
query: str,
start_date: str = None,
end_date: str = None,
max_results: int = 100
) -> List[Dict]:
"""
์ง€์ •๋œ ๋‚ ์งœ ๋ฒ”์œ„ ๋‚ด์˜ ๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ์ฑ„ํŒ… ๋ฐ์ดํ„ฐ๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค.
Args:
query (str): ๊ฒ€์ƒ‰ํ•  ์ฟผ๋ฆฌ ํ…์ŠคํŠธ
start_date (str): ๊ฒ€์ƒ‰ ์‹œ์ž‘ ๋‚ ์งœ (YYYY-MM-DD ํ˜•์‹)
end_date (str): ๊ฒ€์ƒ‰ ์ข…๋ฃŒ ๋‚ ์งœ (YYYY-MM-DD ํ˜•์‹)
max_results (int): ๋ฐ˜ํ™˜ํ•  ์ตœ๋Œ€ ๊ฒฐ๊ณผ ์ˆ˜
Returns:
List[Dict]: ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ๋ชฉ๋ก
"""
limit = max_results if max_results is not None else 100
# ๊ฐ€์ค‘์น˜ ์„ค์ •
full_w = DEFAULT_FULL_WEIGHT
topic_w = DEFAULT_TOPIC_WEIGHT
customer_w = DEFAULT_CUSTOMER_WEIGHT
agent_w = DEFAULT_AGENT_WEIGHT
threshold = DEFAULT_SIMILARITY_THRESHOLD
print(f"๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ๋‚ ์งœ ๊ฒ€์ƒ‰ ์‹œ์ž‘: ์ฟผ๋ฆฌ='{query}', ์‹œ์ž‘์ผ={start_date}, ์ข…๋ฃŒ์ผ={end_date}, ์ตœ๋Œ€ ๊ฒฐ๊ณผ={limit}")
try:
# ๋‚ ์งœ ํ•„ํ„ฐ ํŒŒ๋ผ๋ฏธํ„ฐ ์ƒ์„ฑ
start_timestamp = None
end_timestamp = None
if start_date and start_date.strip():
try:
start_datetime = datetime.strptime(start_date, '%Y-%m-%d')
start_timestamp = int(start_datetime.timestamp() * 1000) # ๋ฐ€๋ฆฌ์ดˆ ๋‹จ์œ„๋กœ ๋ณ€ํ™˜
except ValueError as e:
print(f"์‹œ์ž‘ ๋‚ ์งœ ํ˜•์‹ ์˜ค๋ฅ˜: {str(e)}")
return []
if end_date and end_date.strip():
try:
# ์ข…๋ฃŒ์ผ์˜ 23:59:59๋กœ ์„ค์ •
end_datetime = datetime.strptime(end_date + ' 23:59:59', '%Y-%m-%d %H:%M:%S')
end_timestamp = int(end_datetime.timestamp() * 1000) # ๋ฐ€๋ฆฌ์ดˆ ๋‹จ์œ„๋กœ ๋ณ€ํ™˜
except ValueError as e:
print(f"์ข…๋ฃŒ ๋‚ ์งœ ํ˜•์‹ ์˜ค๋ฅ˜: {str(e)}")
return []
# ์ฟผ๋ฆฌ ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ
raw_embedding = np.array(get_embedding(query))
# L2 ์ •๊ทœํ™” ์ ์šฉ
query_embedding = normalize(raw_embedding.reshape(1, -1), norm='l2')[0]
print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ - ์ž„๋ฒ ๋”ฉ ์ •๊ทœํ™” ์ „/ํ›„ ์ฒซ 5๊ฐœ ์š”์†Œ: {raw_embedding[:5]} -> {query_embedding[:5]}")
# Java ๋ฐฉ์‹: ๋ฒกํ„ฐ๋ฅผ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
query_vector = format_vector_for_pg(query_embedding)
# DB ์—ฐ๊ฒฐ
conn = get_db_conn()
register_vector(conn)
# Java ๋ฐฉ์‹: ๋ฌธ์ž์—ด ํฌ๋งทํŒ… ์‚ฌ์šฉํ•œ SQL ์ฟผ๋ฆฌ ์‹œ์ž‘
sql = f"""
WITH embeddings AS (
SELECT
id,
metadata,
content,
CASE WHEN full_embedding IS NOT NULL THEN (full_embedding <=> '{query_vector}'::vector) ELSE 0 END * {full_w} as full_sim,
CASE WHEN topic_embedding IS NOT NULL THEN (topic_embedding <=> '{query_vector}'::vector) ELSE 0 END * {topic_w} as topic_sim,
CASE WHEN customer_embedding IS NOT NULL THEN (customer_embedding <=> '{query_vector}'::vector) ELSE 0 END * {customer_w} as customer_sim,
CASE WHEN agent_embedding IS NOT NULL THEN (agent_embedding <=> '{query_vector}'::vector) ELSE 0 END * {agent_w} as agent_sim
FROM vector_store_multi_embeddings
WHERE full_embedding IS NOT NULL
OR topic_embedding IS NOT NULL
OR customer_embedding IS NOT NULL
OR agent_embedding IS NOT NULL
"""
# ๋‚ ์งœ ํ•„ํ„ฐ ์ถ”๊ฐ€
if start_timestamp is not None:
sql += f" AND (metadata->>'startTime')::bigint >= {start_timestamp}"
if end_timestamp is not None:
sql += f" AND (metadata->>'startTime')::bigint <= {end_timestamp}"
sql += """
)
SELECT
id,
metadata,
content,
(full_sim + topic_sim + customer_sim + agent_sim) as combined_similarity
FROM embeddings
ORDER BY combined_similarity DESC
LIMIT %s
"""
with conn.cursor() as cur:
print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ ์‹คํ–‰: ์‹œ์ž‘์ผ={start_date}({start_timestamp}), ์ข…๋ฃŒ์ผ={end_date}({end_timestamp})")
# ์—ฌ๊ธฐ์„œ๋Š” limit๋งŒ ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ์ „๋‹ฌ
cur.execute(sql, (limit,))
rows = cur.fetchall()
print(f"๋‚ ์งœ ํ•„ํ„ฐ๋ง ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ: ์ด {len(rows)}๊ฐœ ๋ฐ์ดํ„ฐ ์กฐํšŒ๋จ")
if len(rows) > 0:
print(f"์ฒซ ๋ฒˆ์งธ ๊ฒฐ๊ณผ ID: {rows[0][0]}, ์œ ์‚ฌ๋„: {float(rows[0][3])}")
results = []
for row in rows:
id_val = row[0]
metadata_json = row[1]
content = row[2]
similarity_score = float(row[3])
# ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ํŒŒ์‹ฑ
try:
metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
result = {
"id": id_val,
"similarityScore": similarity_score,
"content": content,
"chatId": get_text_value(metadata, "chatId"),
"topic": get_text_value(metadata, "topic")
}
# ์‹œ๊ฐ„ ํ•„๋“œ ๋ณ€ํ™˜ ์—†์ด ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ
if "startTime" in metadata and metadata["startTime"] is not None:
result["startTime"] = metadata["startTime"]
if "endTime" in metadata and metadata["endTime"] is not None:
result["endTime"] = metadata["endTime"]
results.append(result)
except Exception as e:
print(f"๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ํŒŒ์‹ฑ ์˜ค๋ฅ˜: {e}")
print(f"๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•œ ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ: {metadata_json[:200]}...")
continue
# ์ž„๊ณ„๊ฐ’ ํ•„ํ„ฐ๋ง
filtered_results = [r for r in results if r["similarityScore"] >= threshold]
print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ - ์ž„๊ณ„๊ฐ’({threshold}) ์ด์ƒ ๊ฒฐ๊ณผ: {len(filtered_results)}๊ฐœ / ์ „์ฒด {len(results)}๊ฐœ")
if len(filtered_results) > 0:
print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ - ๊ฐ€์žฅ ๋†’์€ ์œ ์‚ฌ๋„ ์ ์ˆ˜: {filtered_results[0]['similarityScore']}")
print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ - ์ƒ์œ„ ๊ฒฐ๊ณผ ์ฑ—ID: {filtered_results[0].get('chatId')}, ์‹œ์ž‘์‹œ๊ฐ„: {filtered_results[0].get('startTime')}")
return filtered_results
except Exception as e:
print(f"๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ๋‚ ์งœ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
return []
finally:
if 'conn' in locals():
conn.close()
# Build the Gradio app: both search functions are registered in one Blocks UI.
with gr.Blocks() as demo:
    gr.Markdown("# Chat Analysis Search")
    # One Interface per search function; inputs map positionally onto the
    # function parameters, output is rendered as JSON, and api_name sets the
    # endpoint name each function is registered under.
    gr.Interface(fn=search_similar_chat, inputs=["text", "number"], outputs="json", api_name="search_similar_chat")
    gr.Interface(fn=search_similar_chat_by_date, inputs=["text", "text", "text", "number"], outputs="json", api_name="search_similar_chat_by_date")
if __name__ == "__main__":
    # mcp_server=True additionally serves the app over MCP — presumably exposing
    # the registered functions as tools; confirm against the Gradio MCP docs.
    demo.launch(mcp_server=True)