Spaces:
Paused
Paused
| import gradio as gr | |
| import psycopg2 | |
| from openai import OpenAI | |
| import json | |
| import os | |
| from typing import List, Dict, Tuple, Any | |
| from pgvector.psycopg2 import register_vector | |
| import numpy as np | |
| from datetime import datetime | |
| from sklearn.preprocessing import normalize | |
# Weight and threshold settings.
# The four weights combine the per-field similarity scores in the search
# functions below; they are expected to sum to 1.0 (0.2+0.5+0.2+0.1).
DEFAULT_FULL_WEIGHT = 0.2       # weight for the full-conversation embedding
DEFAULT_TOPIC_WEIGHT = 0.5      # weight for the topic embedding
DEFAULT_CUSTOMER_WEIGHT = 0.2   # weight for the customer-utterance embedding
DEFAULT_AGENT_WEIGHT = 0.1      # weight for the agent-utterance embedding
DEFAULT_SIMILARITY_THRESHOLD = 0  # minimum combined score for a result to be kept
# DB connection settings
def get_db_conn():
    """Open a new PostgreSQL connection configured from VECTOR_* env vars.

    Raises KeyError if any of the required environment variables is missing.
    The caller is responsible for closing the returned connection.
    """
    connect_kwargs = {
        "host": os.environ["VECTOR_HOST"],
        "port": 5432,
        "dbname": os.environ["VECTOR_DBNAME"],
        "user": os.environ["VECTOR_USER"],
        "password": os.environ["VECTOR_SECRET"],
    }
    return psycopg2.connect(**connect_kwargs)
# Shared OpenAI API client (presumably configured via environment variables
# such as OPENAI_API_KEY — the SDK's default behavior).
client = OpenAI()
def get_embedding(text: str) -> List[float]:
    """Embed *text* with OpenAI's text-embedding-3-small model.

    Returns the raw (un-normalized) embedding vector for the first input.
    """
    result = client.embeddings.create(
        model="text-embedding-3-small",
        input=text,
    )
    return result.data[0].embedding
def get_text_value(node, field_name):
    """Safely extract *field_name* from a JSON-like dict.

    Returns None when *node* is falsy, the key is absent, or the value is None.
    """
    if not node:
        return None
    return node.get(field_name)
def format_vector_for_pg(vector: List[float]) -> str:
    """Render *vector* as a pgvector literal string, e.g. "[0.1,0.2,0.3]"."""
    elements = ",".join(map(str, vector))
    return "[" + elements + "]"
def search_similar_chat(query: str, max_results: int = 100) -> List[Dict]:
    """Search the multi-embedding chat store for content similar to *query*.

    Each stored chat carries up to four embeddings (full conversation, topic,
    customer utterances, agent utterances). The query text is embedded,
    L2-normalized, compared against each field with pgvector cosine distance,
    and the per-field similarities are combined using the module-level
    DEFAULT_*_WEIGHT constants.

    Args:
        query (str): query text to search for.
        max_results (int): maximum number of results to return (default 100).

    Returns:
        List[Dict]: records with id, similarityScore, content, chatId, topic
        and optional startTime/endTime, best match first. Empty list on error.
    """
    limit = max_results if max_results is not None else 100
    # Weight settings
    full_w = DEFAULT_FULL_WEIGHT
    topic_w = DEFAULT_TOPIC_WEIGHT
    customer_w = DEFAULT_CUSTOMER_WEIGHT
    agent_w = DEFAULT_AGENT_WEIGHT
    threshold = DEFAULT_SIMILARITY_THRESHOLD
    print(f"๋ค์ค ์๋ฒ ๋ฉ ๊ฒ์ ์์: ์ฟผ๋ฆฌ='{query}', ๊ฐ์ค์น=(full={full_w}, topic={topic_w}, customer={customer_w}, agent={agent_w}), ์ต๋ ๊ฒฐ๊ณผ={limit}")
    conn = None  # so the finally-block can safely test whether a connection was opened
    try:
        # Build the query embedding and L2-normalize it so cosine distance is well-defined.
        raw_embedding = np.array(get_embedding(query))
        query_embedding = normalize(raw_embedding.reshape(1, -1), norm='l2')[0]
        print(f"์๋ฒ ๋ฉ ์ ๊ทํ ์ /ํ ์ฒซ 5๊ฐ ์์: {raw_embedding[:5]} -> {query_embedding[:5]}")
        query_vector = format_vector_for_pg(query_embedding)
        # DB connection
        conn = get_db_conn()
        register_vector(conn)
        # FIX: pgvector's <=> operator returns cosine DISTANCE (0 = identical),
        # so similarity is 1 - distance. The previous code summed raw distances
        # and sorted DESC, which put the LEAST similar rows first, and its
        # ELSE 0 made a missing embedding look like a perfect match. With
        # 1 - distance, ELSE 0 correctly contributes zero similarity.
        # The vector and LIMIT are bound as query parameters rather than being
        # interpolated into the SQL text; the weights are module constants.
        sql = f"""
        WITH embeddings AS (
            SELECT
                id,
                metadata,
                content,
                CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> %(vec)s::vector) ELSE 0 END * {full_w} as full_sim,
                CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> %(vec)s::vector) ELSE 0 END * {topic_w} as topic_sim,
                CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> %(vec)s::vector) ELSE 0 END * {customer_w} as customer_sim,
                CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> %(vec)s::vector) ELSE 0 END * {agent_w} as agent_sim
            FROM vector_store_multi_embeddings
            WHERE full_embedding IS NOT NULL
               OR topic_embedding IS NOT NULL
               OR customer_embedding IS NOT NULL
               OR agent_embedding IS NOT NULL
        )
        SELECT
            id,
            metadata,
            content,
            (full_sim + topic_sim + customer_sim + agent_sim) as combined_similarity
        FROM embeddings
        ORDER BY combined_similarity DESC
        LIMIT %(limit)s
        """
        with conn.cursor() as cur:
            print(f"์ฟผ๋ฆฌ ์คํ - ๊ฐ์ค์น ์ค์ ={full_w}, {topic_w}, {customer_w}, {agent_w}, ๊ฒฐ๊ณผ ์ ํ={limit}")
            cur.execute(sql, {"vec": query_vector, "limit": limit})
            rows = cur.fetchall()
        print(f"๊ฒ์ ๊ฒฐ๊ณผ: ์ด {len(rows)}๊ฐ ๋ฐ์ดํฐ ์กฐํ๋จ")
        if len(rows) > 0:
            print(f"์ฒซ ๋ฒ์งธ ๊ฒฐ๊ณผ ID: {rows[0][0]}, ์ ์ฌ๋: {float(rows[0][3])}")
        results = []
        for row in rows:
            id_val = row[0]
            metadata_json = row[1]
            content = row[2]
            similarity_score = float(row[3])
            # Metadata parsing: jsonb columns may arrive pre-parsed as dicts.
            try:
                metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
                result = {
                    "id": id_val,
                    "similarityScore": similarity_score,
                    "content": content,
                    "chatId": get_text_value(metadata, "chatId"),
                    "topic": get_text_value(metadata, "topic")
                }
                # Time fields are copied through unchanged when present.
                if "startTime" in metadata and metadata["startTime"] is not None:
                    result["startTime"] = metadata["startTime"]
                if "endTime" in metadata and metadata["endTime"] is not None:
                    result["endTime"] = metadata["endTime"]
                results.append(result)
            except Exception as e:
                print(f"๋ฉํ๋ฐ์ดํฐ ํ์ฑ ์ค๋ฅ: {e}")
                # str() guards against metadata_json being a dict, which is not sliceable.
                print(f"๋ฌธ์ ๊ฐ ๋ฐ์ํ ๋ฉํ๋ฐ์ดํฐ: {str(metadata_json)[:200]}...")
                continue
        # Threshold filtering
        filtered_results = [r for r in results if r["similarityScore"] >= threshold]
        print(f"์๊ณ๊ฐ({threshold}) ์ด์ ๊ฒฐ๊ณผ: {len(filtered_results)}๊ฐ / ์ ์ฒด {len(results)}๊ฐ")
        if len(filtered_results) > 0:
            print(f"๊ฐ์ฅ ๋์ ์ ์ฌ๋ ์ ์: {filtered_results[0]['similarityScore']}")
            print(f"์์ ๊ฒฐ๊ณผ ์ฑID: {filtered_results[0].get('chatId')}, ์ฃผ์ : {filtered_results[0].get('topic', '')[:50]}...")
        return filtered_results
    except Exception as e:
        print(f"๋ค์ค ์๋ฒ ๋ฉ ๊ฒ์ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
        return []
    finally:
        if conn is not None:
            conn.close()
def search_similar_chat_by_date(
    query: str,
    start_date: str = None,
    end_date: str = None,
    max_results: int = 100
) -> List[Dict]:
    """Search multi-embedding chat data within a given date range.

    Works like search_similar_chat but additionally filters rows by the
    metadata 'startTime' field (epoch milliseconds, per the *1000 conversion).

    Args:
        query (str): query text to search for.
        start_date (str): inclusive start date, YYYY-MM-DD format (optional).
        end_date (str): inclusive end date, YYYY-MM-DD format (optional);
            interpreted as end-of-day 23:59:59.
        max_results (int): maximum number of results to return (default 100).

    Returns:
        List[Dict]: matching records, best match first. Empty list on error
        or when a date fails to parse.
    """
    limit = max_results if max_results is not None else 100
    # Weight settings
    full_w = DEFAULT_FULL_WEIGHT
    topic_w = DEFAULT_TOPIC_WEIGHT
    customer_w = DEFAULT_CUSTOMER_WEIGHT
    agent_w = DEFAULT_AGENT_WEIGHT
    threshold = DEFAULT_SIMILARITY_THRESHOLD
    print(f"๋ค์ค ์๋ฒ ๋ฉ ๋ ์ง ๊ฒ์ ์์: ์ฟผ๋ฆฌ='{query}', ์์์ผ={start_date}, ์ข ๋ฃ์ผ={end_date}, ์ต๋ ๊ฒฐ๊ณผ={limit}")
    conn = None  # so the finally-block can safely test whether a connection was opened
    try:
        # Build date-filter parameters as epoch milliseconds.
        # NOTE(review): datetime.strptime yields naive local-time datetimes;
        # confirm stored startTime values use the same timezone convention.
        start_timestamp = None
        end_timestamp = None
        if start_date and start_date.strip():
            try:
                start_datetime = datetime.strptime(start_date, '%Y-%m-%d')
                start_timestamp = int(start_datetime.timestamp() * 1000)  # milliseconds
            except ValueError as e:
                print(f"์์ ๋ ์ง ํ์ ์ค๋ฅ: {str(e)}")
                return []
        if end_date and end_date.strip():
            try:
                # End date is inclusive: extend to 23:59:59 of that day.
                end_datetime = datetime.strptime(end_date + ' 23:59:59', '%Y-%m-%d %H:%M:%S')
                end_timestamp = int(end_datetime.timestamp() * 1000)  # milliseconds
            except ValueError as e:
                print(f"์ข ๋ฃ ๋ ์ง ํ์ ์ค๋ฅ: {str(e)}")
                return []
        # Build the query embedding and L2-normalize it.
        raw_embedding = np.array(get_embedding(query))
        query_embedding = normalize(raw_embedding.reshape(1, -1), norm='l2')[0]
        print(f"๋ ์ง ๊ฒ์ - ์๋ฒ ๋ฉ ์ ๊ทํ ์ /ํ ์ฒซ 5๊ฐ ์์: {raw_embedding[:5]} -> {query_embedding[:5]}")
        query_vector = format_vector_for_pg(query_embedding)
        # DB connection
        conn = get_db_conn()
        register_vector(conn)
        # FIX 1: pgvector's <=> returns cosine DISTANCE; similarity is
        # 1 - distance. The old raw-distance sum sorted DESC returned the
        # LEAST similar rows first (see search_similar_chat).
        # FIX 2: the OR chain is parenthesized. Previously the appended
        # "AND (metadata->>'startTime') >= ..." bound only to the last OR term
        # (AND binds tighter than OR), so the date filter was effectively
        # applied only to rows with a non-null agent_embedding.
        # Vector, timestamps and LIMIT are bound as query parameters.
        sql = f"""
        WITH embeddings AS (
            SELECT
                id,
                metadata,
                content,
                CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> %(vec)s::vector) ELSE 0 END * {full_w} as full_sim,
                CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> %(vec)s::vector) ELSE 0 END * {topic_w} as topic_sim,
                CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> %(vec)s::vector) ELSE 0 END * {customer_w} as customer_sim,
                CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> %(vec)s::vector) ELSE 0 END * {agent_w} as agent_sim
            FROM vector_store_multi_embeddings
            WHERE (full_embedding IS NOT NULL
                OR topic_embedding IS NOT NULL
                OR customer_embedding IS NOT NULL
                OR agent_embedding IS NOT NULL)
        """
        params = {"vec": query_vector, "limit": limit}
        # Append the optional date filters (now correctly ANDed with the whole OR group).
        if start_timestamp is not None:
            sql += " AND (metadata->>'startTime')::bigint >= %(start_ts)s"
            params["start_ts"] = start_timestamp
        if end_timestamp is not None:
            sql += " AND (metadata->>'startTime')::bigint <= %(end_ts)s"
            params["end_ts"] = end_timestamp
        sql += """
        )
        SELECT
            id,
            metadata,
            content,
            (full_sim + topic_sim + customer_sim + agent_sim) as combined_similarity
        FROM embeddings
        ORDER BY combined_similarity DESC
        LIMIT %(limit)s
        """
        with conn.cursor() as cur:
            print(f"๋ ์ง ๊ฒ์ ์ฟผ๋ฆฌ ์คํ: ์์์ผ={start_date}({start_timestamp}), ์ข ๋ฃ์ผ={end_date}({end_timestamp})")
            cur.execute(sql, params)
            rows = cur.fetchall()
        print(f"๋ ์ง ํํฐ๋ง ๊ฒ์ ๊ฒฐ๊ณผ: ์ด {len(rows)}๊ฐ ๋ฐ์ดํฐ ์กฐํ๋จ")
        if len(rows) > 0:
            print(f"์ฒซ ๋ฒ์งธ ๊ฒฐ๊ณผ ID: {rows[0][0]}, ์ ์ฌ๋: {float(rows[0][3])}")
        results = []
        for row in rows:
            id_val = row[0]
            metadata_json = row[1]
            content = row[2]
            similarity_score = float(row[3])
            # Metadata parsing: jsonb columns may arrive pre-parsed as dicts.
            try:
                metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
                result = {
                    "id": id_val,
                    "similarityScore": similarity_score,
                    "content": content,
                    "chatId": get_text_value(metadata, "chatId"),
                    "topic": get_text_value(metadata, "topic")
                }
                # Time fields are copied through unchanged when present.
                if "startTime" in metadata and metadata["startTime"] is not None:
                    result["startTime"] = metadata["startTime"]
                if "endTime" in metadata and metadata["endTime"] is not None:
                    result["endTime"] = metadata["endTime"]
                results.append(result)
            except Exception as e:
                print(f"๋ฉํ๋ฐ์ดํฐ ํ์ฑ ์ค๋ฅ: {e}")
                # str() guards against metadata_json being a dict, which is not sliceable.
                print(f"๋ฌธ์ ๊ฐ ๋ฐ์ํ ๋ฉํ๋ฐ์ดํฐ: {str(metadata_json)[:200]}...")
                continue
        # Threshold filtering
        filtered_results = [r for r in results if r["similarityScore"] >= threshold]
        print(f"๋ ์ง ๊ฒ์ - ์๊ณ๊ฐ({threshold}) ์ด์ ๊ฒฐ๊ณผ: {len(filtered_results)}๊ฐ / ์ ์ฒด {len(results)}๊ฐ")
        if len(filtered_results) > 0:
            print(f"๋ ์ง ๊ฒ์ - ๊ฐ์ฅ ๋์ ์ ์ฌ๋ ์ ์: {filtered_results[0]['similarityScore']}")
            print(f"๋ ์ง ๊ฒ์ - ์์ ๊ฒฐ๊ณผ ์ฑID: {filtered_results[0].get('chatId')}, ์์์๊ฐ: {filtered_results[0].get('startTime')}")
        return filtered_results
    except Exception as e:
        print(f"๋ค์ค ์๋ฒ ๋ฉ ๋ ์ง ๊ฒ์ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
        return []
    finally:
        if conn is not None:
            conn.close()
# Register the search functions with a Gradio Blocks app.
with gr.Blocks() as demo:
    gr.Markdown("# Chat Analysis Search")
    # Each gr.Interface exposes its fn as a JSON endpoint under api_name.
    gr.Interface(fn=search_similar_chat, inputs=["text", "number"], outputs="json", api_name="search_similar_chat")
    gr.Interface(fn=search_similar_chat_by_date, inputs=["text", "text", "text", "number"], outputs="json", api_name="search_similar_chat_by_date")
if __name__ == "__main__":
    # mcp_server=True additionally serves the endpoints as MCP tools.
    demo.launch(mcp_server=True)