ChatInsight / app.py
Jake-seong's picture
Update app.py
ab044f6 verified
raw
history blame
14.5 kB
import gradio as gr
import psycopg2
from openai import OpenAI
import json
import os
from typing import List, Dict
from pgvector.psycopg2 import register_vector
import numpy as np
# Search tuning: per-column similarity weights (0.2 + 0.5 + 0.2 + 0.1 = 1.0)
# and the minimum combined score a result must reach to be returned.
DEFAULT_FULL_WEIGHT: float = 0.2       # weight for full_embedding
DEFAULT_TOPIC_WEIGHT: float = 0.5      # weight for topic_embedding
DEFAULT_CUSTOMER_WEIGHT: float = 0.2   # weight for customer_embedding
DEFAULT_AGENT_WEIGHT: float = 0.1      # weight for agent_embedding
DEFAULT_SIMILARITY_THRESHOLD = 0       # results scoring below this are dropped
# Shared OpenAI client, created once at import time; get_embedding() sends
# its embedding requests through it.
client = OpenAI()
# Database connection settings
def get_db_conn():
    """
    Open a new connection to the PostgreSQL vector database.

    Connection settings come from environment variables:
        VECTOR_HOST, VECTOR_DBNAME, VECTOR_USER, VECTOR_SECRET (password) -- required.
        VECTOR_PORT -- optional, defaults to 5432.

    Returns:
        A new psycopg2 connection. The caller is responsible for closing it.

    Raises:
        KeyError: If a required environment variable is not set.
        ValueError: If VECTOR_PORT is set but is not an integer.
    """
    return psycopg2.connect(
        host=os.environ["VECTOR_HOST"],
        # Optional override so non-default ports work without a code change.
        port=int(os.environ.get("VECTOR_PORT", "5432")),
        dbname=os.environ["VECTOR_DBNAME"],
        user=os.environ["VECTOR_USER"],
        password=os.environ["VECTOR_SECRET"],
    )
def get_embedding(text: str, model: str = "text-embedding-ada-002") -> List[float]:
    """
    Convert text into an embedding vector with the OpenAI embeddings API.

    The default model is text-embedding-ada-002. Query embeddings presumably
    have to come from the same model that produced the vectors stored in the
    database, so only override `model` if the stored vectors were generated
    with it as well (NOTE(review): confirm against the ingestion pipeline).
    The values are passed through float32 so they match a Java float[].

    Args:
        text (str): Text to embed. Must be non-empty.
        model (str): OpenAI embedding model name.

    Returns:
        List[float]: Embedding vector (each value is representable as float32).

    Raises:
        ValueError: If `text` is empty or None (checked before any API call).
        Exception: Any error from the OpenAI client is printed and re-raised.
    """
    # Reject empty input up front instead of spending a network round trip
    # on a request the API would refuse anyway.
    if not text:
        raise ValueError("text must be a non-empty string")
    try:
        response = client.embeddings.create(
            input=text,
            model=model,
            encoding_format="float",
        )
        # Explicit float32 round trip keeps the values compatible with Java float[].
        return np.array(response.data[0].embedding, dtype=np.float32).tolist()
    except Exception as e:
        print(f"์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
        raise
def format_vector_for_pg(vector: List[float]) -> str:
    """
    Render a vector as a pgvector text literal such as "[1.0,2.5]".

    Counterpart of the Java formatVectorForPg() method. Lists, tuples and NumPy
    arrays of any dtype are accepted; every element is cast to float32 first,
    then formatted with Python's default format() and joined with commas
    (no spaces), wrapped in square brackets.
    """
    as_f32 = np.asarray(vector, dtype=np.float32)
    body = ",".join(map(format, as_f32))
    return "[" + body + "]"
def get_text_value(node: Dict, field_name: str) -> str:
    """
    Safely read `field_name` from the `node` dictionary.

    Counterpart of the Java getTextValue() method. The stored value is returned
    as-is (it is not converted to str). None is returned when `node` is empty or
    None, when the key is absent, or when the stored value is itself None.
    """
    if not node:
        return None
    if field_name not in node:
        return None
    # A stored None falls through unchanged, so the result is None as well.
    return node[field_name]
def _build_similarity_sql(extra_where: str = "") -> str:
    """
    Build the weighted multi-embedding similarity query.

    Each of the four embedding columns contributes
    (1 - (column <=> query_vector)) * weight, or 0 when the column is NULL;
    rows are ranked by the sum of the four terms.

    Every value is a psycopg2 named placeholder, to be supplied as a bound
    parameter:
        %(vec)s         pgvector text literal such as "[0.1,0.2]"
        %(full_w)s, %(topic_w)s, %(customer_w)s, %(agent_w)s   weights
        %(limit)s       maximum number of rows

    `extra_where` is appended after the parenthesised "any embedding present"
    condition. It must start with " AND" and must contain only placeholders,
    never caller-supplied text. Any literal % in it must be written as %%.
    """
    # The OR group is parenthesised so that conditions appended through
    # `extra_where` apply to every row, not just the last OR term
    # (AND binds tighter than OR).
    return f"""
        WITH embeddings AS (
            SELECT
                id,
                metadata,
                content,
                CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> %(vec)s::vector) ELSE 0 END * %(full_w)s AS full_sim,
                CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> %(vec)s::vector) ELSE 0 END * %(topic_w)s AS topic_sim,
                CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> %(vec)s::vector) ELSE 0 END * %(customer_w)s AS customer_sim,
                CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> %(vec)s::vector) ELSE 0 END * %(agent_w)s AS agent_sim
            FROM vector_store_multi_embeddings
            WHERE (full_embedding IS NOT NULL
                OR topic_embedding IS NOT NULL
                OR customer_embedding IS NOT NULL
                OR agent_embedding IS NOT NULL){extra_where}
        )
        SELECT
            id,
            metadata,
            content,
            (full_sim + topic_sim + customer_sim + agent_sim) AS combined_similarity
        FROM embeddings
        ORDER BY combined_similarity DESC
        LIMIT %(limit)s
    """


def _similarity_params(query_vector: str, full_w: float, topic_w: float,
                       customer_w: float, agent_w: float, limit) -> Dict:
    """
    Bundle the bound parameters that _build_similarity_sql() expects.

    `limit` is coerced to int because the numeric UI input may deliver a
    float such as 100.0.
    """
    return {
        "vec": query_vector,
        "full_w": float(full_w),
        "topic_w": float(topic_w),
        "customer_w": float(customer_w),
        "agent_w": float(agent_w),
        "limit": int(limit),
    }


def _fetch_similarity_rows(sql: str, params: Dict, log_message: str) -> List[tuple]:
    """
    Run `sql` with `params` on a fresh connection and return all rows.

    `log_message` is printed just before execution. The connection is always
    closed, even if vector registration or the query itself fails.
    """
    conn = get_db_conn()
    try:
        register_vector(conn)
        with conn.cursor() as cur:
            print(log_message)
            cur.execute(sql, params)
            return cur.fetchall()
    finally:
        conn.close()


def _rows_to_results(rows) -> List[Dict]:
    """
    Convert (id, metadata, content, combined_similarity) rows into result dicts.

    A row whose metadata cannot be used (e.g. invalid JSON text, or NULL) is
    logged and skipped instead of failing the whole search.
    """
    if rows:
        print(f"์ฒซ ๋ฒˆ์งธ ๊ฒฐ๊ณผ ID: {rows[0][0]}, ์œ ์‚ฌ๋„: {float(rows[0][3])}")
    results = []
    for id_val, metadata_json, content, raw_score in rows:
        similarity_score = float(raw_score)
        try:
            # The driver may hand back either JSON text or an already-decoded value.
            metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
            result = {
                "id": id_val,
                "similarityScore": similarity_score,
                "content": content,
                "chatId": get_text_value(metadata, "chatId"),
                "topic": get_text_value(metadata, "topic"),
            }
            # Time fields are passed through without conversion
            # (assumed to be stored in KST already).
            if "startTime" in metadata and metadata["startTime"] is not None:
                result["startTime"] = metadata["startTime"]
            if "endTime" in metadata and metadata["endTime"] is not None:
                result["endTime"] = metadata["endTime"]
            results.append(result)
        except Exception as e:
            print(f"๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ํŒŒ์‹ฑ ์˜ค๋ฅ˜: {e}")
            # str() first: metadata may be NULL or an already-decoded dict, and
            # slicing those would raise inside this handler and abort the search.
            print(f"๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•œ ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ: {str(metadata_json)[:200]}...")
            continue
    return results


def search_similar_chat(query: str, max_results: int = 100) -> List[Dict]:
    """
    Search the chat data for content similar to the query.

    The query is embedded once and compared against the four embedding columns
    of vector_store_multi_embeddings. Per-column similarities are combined
    with the DEFAULT_*_WEIGHT constants, and results scoring below
    DEFAULT_SIMILARITY_THRESHOLD are dropped.

    Args:
        query (str): Search text.
        max_results (int): Maximum number of results to return (None means 100).

    Returns:
        List[Dict]: Matches ordered by descending combined similarity. Each has
        id, similarityScore, content, chatId and topic, plus startTime/endTime
        when present. Returns [] on any error (the error is printed).
    """
    limit = max_results if max_results is not None else 100
    # Same weights as the Java implementation.
    full_w = DEFAULT_FULL_WEIGHT
    topic_w = DEFAULT_TOPIC_WEIGHT
    customer_w = DEFAULT_CUSTOMER_WEIGHT
    agent_w = DEFAULT_AGENT_WEIGHT
    threshold = DEFAULT_SIMILARITY_THRESHOLD
    print(f"๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ๊ฒ€์ƒ‰ ์‹œ์ž‘: ์ฟผ๋ฆฌ='{query}', ๊ฐ€์ค‘์น˜=(full={full_w}, topic={topic_w}, customer={customer_w}, agent={agent_w}), ์ตœ๋Œ€ ๊ฒฐ๊ณผ={limit}")
    try:
        query_vector = format_vector_for_pg(get_embedding(query))
        params = _similarity_params(query_vector, full_w, topic_w, customer_w, agent_w, limit)
        rows = _fetch_similarity_rows(
            _build_similarity_sql(),
            params,
            "์ฟผ๋ฆฌ ์‹คํ–‰: ์ž๋ฐ” ๊ตฌํ˜„๊ณผ ๋™์ผํ•˜๊ฒŒ ์ˆ˜์ •",
        )
        print(f"๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ: ์ด {len(rows)}๊ฐœ ๋ฐ์ดํ„ฐ ์กฐํšŒ๋จ")
        results = _rows_to_results(rows)
        # Threshold filtering
        filtered_results = [r for r in results if r["similarityScore"] >= threshold]
        if filtered_results:
            top = filtered_results[0]
            print(f"์ž„๊ณ„๊ฐ’({threshold}) ์ด์ƒ ๊ฒฐ๊ณผ: {len(filtered_results)}๊ฐœ / ์ „์ฒด {len(results)}๊ฐœ")
            print(f"๊ฐ€์žฅ ๋†’์€ ์œ ์‚ฌ๋„ ์ ์ˆ˜: {top['similarityScore']}")
            # topic may be None (missing from metadata); slicing None would raise
            # and turn an otherwise successful search into [].
            print(f"์ƒ์œ„ ๊ฒฐ๊ณผ ์ฑ—ID: {top.get('chatId')}, ์ฃผ์ œ: {str(top.get('topic') or '')[:50]}...")
        else:
            print(f"์ž„๊ณ„๊ฐ’({threshold}) ์ด์ƒ์˜ ๊ฒฐ๊ณผ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค")
        return filtered_results
    except Exception as e:
        print(f"๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
        return []


def search_similar_chat_by_date(
    query: str,
    start_date: str = None,
    end_date: str = None,
    max_results: int = 100
) -> List[Dict]:
    """
    Search chat data similar to the query, restricted to a date range.

    Both bounds are compared as strings against metadata->>'startTime'
    (start_date + "T00:00:00" and end_date + "T23:59:59"), which assumes
    startTime is stored as an ISO-8601 string in that same layout. If
    startTime values carry fractional seconds, entries after 23:59:59 on the
    end date fall outside the range. The bounds apply to every candidate row.

    Args:
        query (str): Search text.
        start_date (str): Inclusive start date (YYYY-MM-DD); blank means no lower bound.
        end_date (str): Inclusive end date (YYYY-MM-DD); blank means no upper bound.
        max_results (int): Maximum number of results to return (None means 100).

    Returns:
        List[Dict]: Same shape as search_similar_chat(); [] on any error.
    """
    limit = max_results if max_results is not None else 100
    # Same weights as the Java implementation.
    full_w = DEFAULT_FULL_WEIGHT
    topic_w = DEFAULT_TOPIC_WEIGHT
    customer_w = DEFAULT_CUSTOMER_WEIGHT
    agent_w = DEFAULT_AGENT_WEIGHT
    threshold = DEFAULT_SIMILARITY_THRESHOLD
    print(f"๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ๋‚ ์งœ ๊ฒ€์ƒ‰ ์‹œ์ž‘: ์ฟผ๋ฆฌ='{query}', ์‹œ์ž‘์ผ={start_date}, ์ข…๋ฃŒ์ผ={end_date}, ์ตœ๋Œ€ ๊ฒฐ๊ณผ={limit}")
    try:
        query_vector = format_vector_for_pg(get_embedding(query))
        params = _similarity_params(query_vector, full_w, topic_w, customer_w, agent_w, limit)
        # The dates come from the UI / API, so they are bound as parameters and
        # never spliced into the SQL text. They are stripped so stray whitespace
        # does not skew the string comparison.
        date_filter = ""
        start = start_date.strip() if start_date else ""
        if start:
            date_filter += " AND metadata->>'startTime' >= %(start_time)s"
            params["start_time"] = start + "T00:00:00"
        end = end_date.strip() if end_date else ""
        if end:
            date_filter += " AND metadata->>'startTime' <= %(end_time)s"
            params["end_time"] = end + "T23:59:59"
        rows = _fetch_similarity_rows(
            _build_similarity_sql(date_filter),
            params,
            f"๋‚ ์งœ ๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ ์‹คํ–‰: ์‹œ์ž‘์ผ={start_date}, ์ข…๋ฃŒ์ผ={end_date}",
        )
        print(f"๋‚ ์งœ ํ•„ํ„ฐ๋ง ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ: ์ด {len(rows)}๊ฐœ ๋ฐ์ดํ„ฐ ์กฐํšŒ๋จ")
        results = _rows_to_results(rows)
        # Threshold filtering (same as the Java implementation).
        filtered_results = [r for r in results if r["similarityScore"] >= threshold]
        if filtered_results:
            top = filtered_results[0]
            print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ - ์ž„๊ณ„๊ฐ’({threshold}) ์ด์ƒ ๊ฒฐ๊ณผ: {len(filtered_results)}๊ฐœ / ์ „์ฒด {len(results)}๊ฐœ")
            print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ - ๊ฐ€์žฅ ๋†’์€ ์œ ์‚ฌ๋„ ์ ์ˆ˜: {top['similarityScore']}")
            print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ - ์ƒ์œ„ ๊ฒฐ๊ณผ ์ฑ—ID: {top.get('chatId')}, ์‹œ์ž‘์‹œ๊ฐ„: {top.get('startTime')}")
        else:
            print(f"๋‚ ์งœ ๊ฒ€์ƒ‰ - ์ž„๊ณ„๊ฐ’({threshold}) ์ด์ƒ์˜ ๊ฒฐ๊ณผ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค")
        return filtered_results
    except Exception as e:
        print(f"๋‹ค์ค‘ ์ž„๋ฒ ๋”ฉ ๋‚ ์งœ ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
        return []
# Gradio web interface: two gr.Interface instances, one per search function,
# each registered under the api_name given below.
with gr.Blocks() as demo:
    gr.Markdown("# ์ฑ„ํŒ… ๋ถ„์„ ๊ฒ€์ƒ‰")
    # search_similar_chat(query, max_results) -> JSON list of matches.
    gr.Interface(fn=search_similar_chat, inputs=["text", "number"], outputs="json", api_name="search_similar_chat")
    # search_similar_chat_by_date(query, start_date, end_date, max_results);
    # the two date fields are free text, expected in YYYY-MM-DD form.
    gr.Interface(fn=search_similar_chat_by_date, inputs=["text", "text", "text", "number"], outputs="json", api_name="search_similar_chat_by_date")
if __name__ == "__main__":
    # NOTE(review): mcp_server=True presumably also exposes the API endpoints as
    # MCP tools via Gradio's MCP support -- confirm against the installed version.
    demo.launch(mcp_server=True)