Spaces:
Paused
Paused
| import gradio as gr | |
| import psycopg2 | |
| from openai import OpenAI | |
| import json | |
| import os | |
| from typing import List, Dict | |
| from pgvector.psycopg2 import register_vector | |
| import numpy as np | |
# Weight and threshold settings.
# Each weight multiplies the cosine similarity of one embedding column in the
# search queries; the weighted terms are summed into a combined score.
DEFAULT_FULL_WEIGHT = 0.2
DEFAULT_TOPIC_WEIGHT = 0.5
DEFAULT_CUSTOMER_WEIGHT = 0.2
DEFAULT_AGENT_WEIGHT = 0.1

# Results whose combined score is below this value are dropped.
DEFAULT_SIMILARITY_THRESHOLD = 0.5
# Database connection settings.
def get_db_conn():
    """Open and return a new psycopg2 connection to the vector database.

    Host, database name, user and password are read from the environment
    variables VECTOR_HOST, VECTOR_DBNAME, VECTOR_USER and VECTOR_SECRET;
    the port is fixed at 5432.

    Raises:
        KeyError: If any of the required environment variables is not set.
    """
    env = os.environ
    connection_settings = {
        "host": env["VECTOR_HOST"],
        "port": 5432,
        "dbname": env["VECTOR_DBNAME"],
        "user": env["VECTOR_USER"],
        "password": env["VECTOR_SECRET"],
    }
    return psycopg2.connect(**connection_settings)
# Module-wide OpenAI client, constructed with default arguments.
client = OpenAI()


def get_embedding(text: str) -> List[float]:
    """Convert text into an embedding vector with text-embedding-ada-002.

    The values returned by the API are passed through a float32 NumPy array
    before being turned back into a list, so every component is rounded to
    float32 precision (the original author's stated goal is compatibility
    with Java's float[]).

    Args:
        text (str): Text to embed.

    Returns:
        List[float]: Embedding vector with float32-rounded components.

    Raises:
        Exception: Any error raised while requesting or converting the
            embedding is logged to stdout and re-raised unchanged.
    """
    try:
        response = client.embeddings.create(
            input=text,
            model="text-embedding-ada-002",
            encoding_format="float",
        )
        # Round through float32 explicitly, then return plain Python floats.
        return np.array(response.data[0].embedding, dtype=np.float32).tolist()
    except Exception as exc:
        # The message had been corrupted by a wrong-charset decode; this is
        # the original Korean text ("error while creating embedding").
        print(f"임베딩 생성 중 오류 발생: {exc}")
        raise
def format_vector_for_pg(vector: List[float]) -> str:
    """Render an embedding vector as a bracketed, comma-separated string.

    The input is coerced to a float32 NumPy array first (lists and arrays of
    any other dtype are converted; float32 arrays are used as they are), and
    each component is then formatted with its default string format.

    Args:
        vector (List[float]): Embedding components, as a list or NumPy array.

    Returns:
        str: Text of the form "[v1,v2,...]"; "[]" for an empty vector.
    """
    as_float32 = np.asarray(vector, dtype=np.float32)
    components = ",".join(format(component) for component in as_float32)
    return "[" + components + "]"
def get_text_value(node: Dict, field_name: str) -> str:
    """Safely read a field from a dictionary.

    Counterpart of the Java getTextValue() helper mentioned by the original
    author.

    Args:
        node (Dict): Mapping to read from; may be None or empty.
        field_name (str): Key to look up.

    Returns:
        The stored value as-is, or None when the mapping is empty/None, the
        key is missing, or the stored value is None.
    """
    if not node or field_name not in node:
        return None
    # A stored None is returned unchanged, which is the "missing" result too.
    return node[field_name]
def _date_filter(start_date, end_date):
    """Build the optional startTime range filter as (SQL clauses, bound params)."""
    clauses = []
    params = {}
    if start_date and start_date.strip():
        # Compare against the start of the day in ISO form.
        clauses.append("metadata->>'startTime' >= %(start_ts)s")
        params["start_ts"] = start_date.strip() + "T00:00:00"
    if end_date and end_date.strip():
        # Compare against the end of the day in ISO form.
        clauses.append("metadata->>'startTime' <= %(end_ts)s")
        params["end_ts"] = end_date.strip() + "T23:59:59"
    return clauses, params


def _build_search_sql(date_clauses):
    """Return the weighted multi-embedding search SQL using named placeholders."""
    date_sql = "".join(f" AND {clause}" for clause in date_clauses)
    # The OR group is parenthesised so that any appended date clause applies
    # to every row; AND binds tighter than OR in SQL.
    return f"""
        WITH embeddings AS (
            SELECT
                id,
                metadata,
                content,
                CASE WHEN full_embedding IS NOT NULL THEN 1 - (full_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(full_w)s AS full_sim,
                CASE WHEN topic_embedding IS NOT NULL THEN 1 - (topic_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(topic_w)s AS topic_sim,
                CASE WHEN customer_embedding IS NOT NULL THEN 1 - (customer_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(customer_w)s AS customer_sim,
                CASE WHEN agent_embedding IS NOT NULL THEN 1 - (agent_embedding <=> %(query_vector)s::vector) ELSE 0 END * %(agent_w)s AS agent_sim
            FROM vector_store_multi_embeddings
            WHERE (full_embedding IS NOT NULL
                OR topic_embedding IS NOT NULL
                OR customer_embedding IS NOT NULL
                OR agent_embedding IS NOT NULL){date_sql}
        )
        SELECT
            id,
            metadata,
            content,
            (full_sim + topic_sim + customer_sim + agent_sim) AS combined_similarity
        FROM embeddings
        ORDER BY combined_similarity DESC
        LIMIT %(limit)s
    """


def _build_result(record_id, raw_metadata, content, score):
    """Turn one result row into the dictionary shape returned to callers."""
    if isinstance(raw_metadata, str):
        metadata = json.loads(raw_metadata)
    else:
        metadata = raw_metadata
    result = {
        "id": record_id,
        "similarityScore": score,
        "content": content,
        "chatId": get_text_value(metadata, "chatId"),
        "topic": get_text_value(metadata, "topic"),
    }
    # Time fields are copied through without conversion, and only when set.
    for key in ("startTime", "endTime"):
        if key in metadata and metadata[key] is not None:
            result[key] = metadata[key]
    return result


def _rows_to_results(rows, threshold):
    """Convert result rows to dictionaries and drop those below the threshold."""
    results = []
    for record_id, raw_metadata, content, raw_score in rows:
        score = float(raw_score)
        try:
            results.append(_build_result(record_id, raw_metadata, content, score))
        except Exception as exc:
            # Best effort: a row with unusable metadata is logged and skipped.
            print(f"메타데이터 파싱 오류: {exc}")
    return [r for r in results if r["similarityScore"] >= threshold]


def _run_similarity_search(query, max_results, start_date=None, end_date=None,
                           error_label=""):
    """Embed the query, run the weighted search and return filtered results."""
    limit = 100 if max_results is None else max_results
    conn = None
    try:
        query_vector = format_vector_for_pg(get_embedding(query))
        date_clauses, params = _date_filter(start_date, end_date)
        params.update(
            query_vector=query_vector,
            full_w=DEFAULT_FULL_WEIGHT,
            topic_w=DEFAULT_TOPIC_WEIGHT,
            customer_w=DEFAULT_CUSTOMER_WEIGHT,
            agent_w=DEFAULT_AGENT_WEIGHT,
            limit=limit,
        )
        conn = get_db_conn()
        register_vector(conn)
        with conn.cursor() as cur:
            # Every value is bound by the driver; nothing is spliced into SQL.
            cur.execute(_build_search_sql(date_clauses), params)
            rows = cur.fetchall()
        return _rows_to_results(rows, DEFAULT_SIMILARITY_THRESHOLD)
    except Exception as exc:
        # Callers receive an empty list on any failure; the error is logged.
        print(f"{error_label}: {exc}")
        return []
    finally:
        if conn is not None:
            conn.close()


def search_similar_chat(query: str, max_results: int = 100) -> List[Dict]:
    """
    Search the chat data for content similar to the query.

    Each row's score is the weighted sum of the cosine similarities between
    the query embedding and the row's full, topic, customer and agent
    embeddings. The top rows are fetched first and those scoring below
    DEFAULT_SIMILARITY_THRESHOLD are then removed.

    Args:
        query (str): Query text to search for.
        max_results (int): Maximum number of rows to fetch (100 when None).

    Returns:
        List[Dict]: Matching results, best first; empty on any error.
    """
    return _run_similarity_search(
        query,
        max_results,
        error_label="다중 임베딩 검색 중 오류 발생",
    )


def search_similar_chat_by_date(
    query: str,
    start_date: str = None,
    end_date: str = None,
    max_results: int = 100
) -> List[Dict]:
    """
    Search the chat data for similar content within a date range.

    Works like search_similar_chat, but additionally keeps only rows whose
    metadata startTime lies between the start of start_date and the end of
    end_date. Either bound may be omitted or blank to leave that side open.

    Args:
        query (str): Query text to search for.
        start_date (str): First day of the range (YYYY-MM-DD format).
        end_date (str): Last day of the range (YYYY-MM-DD format).
        max_results (int): Maximum number of rows to fetch (100 when None).

    Returns:
        List[Dict]: Matching results, best first; empty on any error.
    """
    return _run_similarity_search(
        query,
        max_results,
        start_date=start_date,
        end_date=end_date,
        error_label="다중 임베딩 날짜 검색 중 오류 발생",
    )
# Gradio web interface: one JSON-returning endpoint per search function.
with gr.Blocks() as demo:
    gr.Markdown("# Chat Analysis Search")
    gr.Interface(
        fn=search_similar_chat,
        inputs=["text", "number"],
        outputs="json",
        api_name="search_similar_chat",
    )
    gr.Interface(
        fn=search_similar_chat_by_date,
        inputs=["text", "text", "text", "number"],
        outputs="json",
        api_name="search_similar_chat_by_date",
    )

if __name__ == "__main__":
    # Launch the app with the MCP server option enabled.
    demo.launch(mcp_server=True)