Spaces:
Paused
Paused
| import gradio as gr | |
| import psycopg2 | |
| from openai import OpenAI | |
| import json | |
| import os | |
| from typing import List, Dict | |
| from pgvector.psycopg2 import register_vector | |
| import numpy as np | |
# Weight and threshold settings.
# Each weight scales the cosine similarity of one embedding column before the
# four scores are summed; results whose combined score falls below the
# threshold are dropped by the search functions.
DEFAULT_FULL_WEIGHT: float = 0.2
DEFAULT_TOPIC_WEIGHT: float = 0.5
DEFAULT_CUSTOMER_WEIGHT: float = 0.2
DEFAULT_AGENT_WEIGHT: float = 0.1
DEFAULT_SIMILARITY_THRESHOLD: float = 0.5
# Database connection settings
def get_db_conn():
    """Open a new PostgreSQL connection configured from environment variables.

    Reads VECTOR_HOST, VECTOR_DBNAME, VECTOR_USER and VECTOR_SECRET from the
    environment; the port is fixed at 5432.

    Returns:
        The connection object returned by ``psycopg2.connect``. The caller is
        responsible for closing it.

    Raises:
        KeyError: If any of the required environment variables is not set.
    """
    env = os.environ
    connection_settings = {
        "host": env["VECTOR_HOST"],
        "port": 5432,
        "dbname": env["VECTOR_DBNAME"],
        "user": env["VECTOR_USER"],
        "password": env["VECTOR_SECRET"],
    }
    return psycopg2.connect(**connection_settings)
# Initialize the OpenAI client
client = OpenAI()


def get_embedding(text: str) -> List[float]:
    """Embed text with OpenAI's text-embedding-ada-002 model.

    The returned values are passed through float32 so that they match the
    precision of a Java ``float[]`` (per the original author's note).

    Args:
        text (str): Text to embed.

    Returns:
        List[float]: Embedding vector whose values were rounded to float32.

    Raises:
        Exception: Any error raised while requesting or converting the
            embedding is printed and then re-raised unchanged.
    """
    try:
        reply = client.embeddings.create(
            model="text-embedding-ada-002",
            input=text,
            encoding_format="float",
        )
        raw_vector = reply.data[0].embedding
        # Round-trip through float32 so the values match Java's float precision.
        as_float32 = np.array(raw_vector, dtype=np.float32)
        return as_float32.tolist()
    except Exception as e:
        print(f"์๋ฒ ๋ฉ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
        raise
def format_vector_for_pg(vector: List[float]) -> str:
    """Render an embedding vector as a pgvector text literal.

    The input is coerced to a float32 NumPy array (unless it already is one)
    and its components are joined with commas inside square brackets,
    e.g. ``[1.0,2.5]``.

    Args:
        vector: Sequence of numbers or a NumPy array.

    Returns:
        str: The bracketed, comma-separated representation.
    """
    if isinstance(vector, np.ndarray):
        # Already an array: only convert when the dtype is not float32.
        as_float32 = vector if vector.dtype == np.float32 else vector.astype(np.float32)
    else:
        as_float32 = np.array(vector, dtype=np.float32)

    # NOTE: format() is used deliberately (same path as an f-string); NumPy
    # scalars may render differently under str(), so do not swap it.
    components = (format(component) for component in as_float32)
    return "[" + ",".join(components) + "]"
def get_text_value(node: Dict, field_name: str) -> str:
    """Safely read a field from a metadata dictionary.

    Counterpart of the Java ``getTextValue()`` helper (per the original
    author's note).

    Args:
        node: Dictionary to read from; may be ``None`` or empty.
        field_name: Key to look up.

    Returns:
        The stored value, or ``None`` when ``node`` is falsy, the key is
        absent, or the stored value is ``None``.
    """
    # Guard clauses: nothing to read from, or nothing stored under the key.
    if not node:
        return None
    if field_name not in node:
        return None
    return node[field_name]
def _to_iso_timestamp(value):
    """Best-effort rendering of a metadata time value as 'YYYY-MM-DDTHH:MM:SS'.

    Strings are returned unchanged. Integers are treated as epoch
    milliseconds and converted with ``datetime.fromtimestamp`` (local time).
    Any other value is parsed from its string form as an ISO-8601 timestamp.
    Values that cannot be converted are returned as-is.
    """
    if isinstance(value, str):
        return value

    from datetime import datetime  # local import: not imported at module level

    try:
        if isinstance(value, int):
            moment = datetime.fromtimestamp(value / 1000)
        else:
            moment = datetime.fromisoformat(str(value).replace('Z', '+00:00'))
        return moment.strftime('%Y-%m-%dT%H:%M:%S')
    except (ValueError, TypeError, OverflowError, OSError):
        # Conversion failed: fall back to the raw value rather than drop it.
        return value


def search_similar_chat(query: str, max_results: int = 100) -> List[Dict]:
    """Search chat data for content similar to the query text.

    The query is embedded, compared by cosine distance against the four
    embedding columns of ``vector_store_multi_embeddings``, and the weighted
    similarities are summed into one score. Rows are fetched in descending
    score order and those below ``DEFAULT_SIMILARITY_THRESHOLD`` are dropped.

    Args:
        query (str): Query text to search for.
        max_results (int): Maximum number of rows fetched from the database
            (applied before the threshold filter). ``None`` means 100.

    Returns:
        List[Dict]: One dict per match with ``id``, ``similarityScore``,
        ``content``, ``chatId``, ``topic`` and, when present in the metadata,
        ``startTime`` / ``endTime``. An empty list is returned when any error
        occurs (the error is printed).
    """
    conn = None
    try:
        # The Gradio "number" input delivers floats; LIMIT needs an integer.
        limit = 100 if max_results is None else int(max_results)

        # Embed the query and render it as a pgvector literal.
        query_vector = format_vector_for_pg(get_embedding(query))

        conn = get_db_conn()
        register_vector(conn)

        # All values are bound as parameters so that psycopg2 does the
        # quoting; nothing user-controlled is spliced into the SQL text.
        sql = """
        WITH embeddings AS (
            SELECT
                id,
                metadata,
                content,
                CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> %(vec)s::vector) ELSE 0 END * %(full_w)s as full_sim,
                CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> %(vec)s::vector) ELSE 0 END * %(topic_w)s as topic_sim,
                CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> %(vec)s::vector) ELSE 0 END * %(customer_w)s as customer_sim,
                CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> %(vec)s::vector) ELSE 0 END * %(agent_w)s as agent_sim
            FROM vector_store_multi_embeddings
            WHERE full_embedding IS NOT NULL
               OR topic_embedding IS NOT NULL
               OR customer_embedding IS NOT NULL
               OR agent_embedding IS NOT NULL
        )
        SELECT
            id,
            metadata,
            content,
            (full_sim + topic_sim + customer_sim + agent_sim) as combined_similarity
        FROM embeddings
        ORDER BY combined_similarity DESC
        LIMIT %(limit)s
        """
        params = {
            "vec": query_vector,
            "full_w": DEFAULT_FULL_WEIGHT,
            "topic_w": DEFAULT_TOPIC_WEIGHT,
            "customer_w": DEFAULT_CUSTOMER_WEIGHT,
            "agent_w": DEFAULT_AGENT_WEIGHT,
            "limit": limit,
        }

        with conn.cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()

        results = []
        for id_val, metadata_json, content, raw_score in rows:
            similarity_score = float(raw_score)
            # A row whose metadata cannot be processed is logged and skipped
            # so that one bad record does not discard the whole result set.
            try:
                metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
                result = {
                    "id": id_val,
                    "similarityScore": similarity_score,
                    "content": content,
                    "chatId": get_text_value(metadata, "chatId"),
                    "topic": get_text_value(metadata, "topic"),
                }
                # Time fields: normalise timestamps to ISO-style strings.
                for time_key in ("startTime", "endTime"):
                    if time_key in metadata and metadata[time_key] is not None:
                        result[time_key] = _to_iso_timestamp(metadata[time_key])
                results.append(result)
            except Exception as e:
                print(f"๋ฉํ๋ฐ์ดํฐ ํ์ฑ ์ค๋ฅ: {e}")
                continue

        # Threshold filtering
        threshold = DEFAULT_SIMILARITY_THRESHOLD
        return [r for r in results if r["similarityScore"] >= threshold]
    except Exception as e:
        print(f"๋ค์ค ์๋ฒ ๋ฉ ๊ฒ์ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
        return []
    finally:
        if conn is not None:
            conn.close()
def search_similar_chat_by_date(
    query: str,
    start_date: str = None,
    end_date: str = None,
    max_results: int = 100
) -> List[Dict]:
    """Search chat data for similar content within a date range.

    Works like the plain similarity search (weighted sum of the cosine
    similarities of the four embedding columns) but additionally restricts
    rows by the ``startTime`` text stored in ``metadata``. The comparison is
    a string comparison against ``<date>T00:00:00`` / ``<date>T23:59:59``.

    Args:
        query (str): Query text to search for.
        start_date (str): First day to include (YYYY-MM-DD format). Blank or
            ``None`` disables the lower bound.
        end_date (str): Last day to include (YYYY-MM-DD format). Blank or
            ``None`` disables the upper bound.
        max_results (int): Maximum number of rows fetched from the database
            (applied before the threshold filter). ``None`` means 100.

    Returns:
        List[Dict]: One dict per match with ``id``, ``similarityScore``,
        ``content``, ``chatId``, ``topic`` and, when present in the metadata,
        ``startTime`` / ``endTime``. An empty list is returned when any error
        occurs (the error is printed).
    """
    conn = None
    try:
        # The Gradio "number" input delivers floats; LIMIT needs an integer.
        limit = 100 if max_results is None else int(max_results)

        # Embed the query and render it as a pgvector literal.
        query_vector = format_vector_for_pg(get_embedding(query))

        conn = get_db_conn()
        register_vector(conn)

        params = {
            "vec": query_vector,
            "full_w": DEFAULT_FULL_WEIGHT,
            "topic_w": DEFAULT_TOPIC_WEIGHT,
            "customer_w": DEFAULT_CUSTOMER_WEIGHT,
            "agent_w": DEFAULT_AGENT_WEIGHT,
            "limit": limit,
        }

        # Date filters: only fixed SQL fragments are appended; the
        # user-supplied dates travel as bound parameters (no string splicing).
        date_filters = []
        if start_date and start_date.strip():
            date_filters.append("AND (metadata->>'startTime') >= %(start_ts)s")
            params["start_ts"] = start_date.strip() + "T00:00:00"
        if end_date and end_date.strip():
            date_filters.append("AND (metadata->>'startTime') <= %(end_ts)s")
            params["end_ts"] = end_date.strip() + "T23:59:59"

        sql = """
        WITH embeddings AS (
            SELECT
                id,
                metadata,
                content,
                CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> %(vec)s::vector) ELSE 0 END * %(full_w)s as full_sim,
                CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> %(vec)s::vector) ELSE 0 END * %(topic_w)s as topic_sim,
                CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> %(vec)s::vector) ELSE 0 END * %(customer_w)s as customer_sim,
                CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> %(vec)s::vector) ELSE 0 END * %(agent_w)s as agent_sim
            FROM vector_store_multi_embeddings
            WHERE (full_embedding IS NOT NULL
               OR topic_embedding IS NOT NULL
               OR customer_embedding IS NOT NULL
               OR agent_embedding IS NOT NULL)
        """
        sql += "\n".join("            " + clause for clause in date_filters)
        sql += """
        )
        SELECT
            id,
            metadata,
            content,
            (full_sim + topic_sim + customer_sim + agent_sim) as combined_similarity
        FROM embeddings
        ORDER BY combined_similarity DESC
        LIMIT %(limit)s
        """

        with conn.cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()

        results = []
        for id_val, metadata_json, content, raw_score in rows:
            similarity_score = float(raw_score)
            # A row whose metadata cannot be processed is logged and skipped
            # so that one bad record does not discard the whole result set.
            try:
                metadata = json.loads(metadata_json) if isinstance(metadata_json, str) else metadata_json
                result = {
                    "id": id_val,
                    "similarityScore": similarity_score,
                    "content": content,
                    "chatId": get_text_value(metadata, "chatId"),
                    "topic": get_text_value(metadata, "topic"),
                }
                # Time fields are passed through unconverted (the original
                # author notes they are already stored in KST).
                for time_key in ("startTime", "endTime"):
                    if time_key in metadata and metadata[time_key] is not None:
                        result[time_key] = metadata[time_key]
                results.append(result)
            except Exception as e:
                print(f"๋ฉํ๋ฐ์ดํฐ ํ์ฑ ์ค๋ฅ: {e}")
                continue

        # Threshold filtering
        threshold = DEFAULT_SIMILARITY_THRESHOLD
        return [r for r in results if r["similarityScore"] >= threshold]
    except Exception as e:
        print(f"๋ค์ค ์๋ฒ ๋ฉ ๋ ์ง ๊ฒ์ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
        return []
    finally:
        if conn is not None:
            conn.close()
# Gradio web interface setup
with gr.Blocks() as demo:
    gr.Markdown("# Chat Analysis Search")

    # Plain similarity search: (query text, max results) -> JSON list.
    gr.Interface(
        fn=search_similar_chat,
        inputs=["text", "number"],
        outputs="json",
        api_name="search_similar_chat",
    )

    # Date-bounded search: (query text, start date, end date, max results) -> JSON list.
    gr.Interface(
        fn=search_similar_chat_by_date,
        inputs=["text", "text", "text", "number"],
        outputs="json",
        api_name="search_similar_chat_by_date",
    )

if __name__ == "__main__":
    # Launch the app with Gradio's MCP server option enabled.
    demo.launch(mcp_server=True)