Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import pickle | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import os | |
| import json | |
| import torch | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig | |
| import chromadb | |
| # ๊ฒฝ๋ก ์ค์ | |
| import sys | |
| BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) | |
| sys.path.append(BASE_DIR) | |
| from scripts.utils.config import EMBEDDING_PKL_PATH, VECTOR_DB_DIR, MODEL_LLM_LATEST | |
| st.set_page_config(layout="wide", page_title="Nyang Smart Retriever Debugger") | |
| # --- Resource Loading --- | |
@st.cache_resource
def load_viz_data():
    """Load the precomputed embedding-visualization bundle from disk.

    Returns:
        The unpickled dict (used by main() with keys 'dataframe',
        'reducers', 'models'), or None when the pickle file is missing.

    Cached with ``st.cache_resource`` so the pickle is read once per
    server process instead of on every Streamlit rerun; ``cache_resource``
    (not ``cache_data``) avoids deep-copying the reducer objects on each
    access.
    """
    if not os.path.exists(EMBEDDING_PKL_PATH):
        return None
    # NOTE(review): pickle.load on a project-generated artifact — fine for
    # trusted local files, never for untrusted input.
    with open(EMBEDDING_PKL_PATH, 'rb') as f:
        return pickle.load(f)
@st.cache_resource
def load_models(embedding_model_path):
    """Load the embedding model and the 4-bit quantized LLM query parser.

    Args:
        embedding_model_path: HF model name/path for the sentence embedder.

    Returns:
        (embeddings, llm_pipe): a HuggingFaceEmbeddings instance and a
        text-generation pipeline capped at 128 new tokens.

    Cached with ``st.cache_resource`` so the models really are loaded only
    once per server process — without the decorator, Streamlit re-executes
    this on every widget interaction and reloads the LLM into VRAM each
    time (the caller's spinner already assumes one-time loading).
    """
    # 1. Embedding model — on GPU, normalized vectors (cosine/IP search).
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_path,
        model_kwargs={'device': 'cuda'},
        encode_kwargs={'normalize_embeddings': True}
    )
    # 2. LLM (query parser) — 4-bit NF4 quantization to fit in VRAM.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_LLM_LATEST)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_LLM_LATEST, quantization_config=bnb_config, device_map="auto"
    )
    llm_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=128)
    return embeddings, llm_pipe
| # --- Core Logic --- | |
def parse_query_with_llm(pipe, query):
    """Use the LLM to turn a natural-language question into structured filters.

    Args:
        pipe: text-generation callable (HF pipeline) accepting
            ``(prompt, do_sample=...)`` and returning
            ``[{'generated_text': ...}]``.
        query: raw user question.

    Returns:
        dict with keys brand/category/price_min/price_max/query parsed from
        the model's JSON output; on any extraction/parsing failure, a
        fallback dict with all filters None and the original query.
    """
    # Llama-3-style chat template; the few-shot example pins the JSON schema.
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
너는 검색 쿼리 분석기다. 사용자 질문에서 브랜드, 카테고리, 가격 조건을 추출하여 JSON으로 출력하라.
가격은 숫자만, 없는 조건은 null로 표기해라.
예시: "3만원대 로얄캐닌 사료"
출력: {{"brand": "로얄캐닌", "category": "사료", "price_min": 30000, "price_max": 40000, "query": "사료 추천"}}
<|eot_id|><|start_header_id|>user<|end_header_id|>
"{query}"
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    fallback = {"brand": None, "category": None, "price_min": None,
                "price_max": None, "query": query}
    try:
        outputs = pipe(prompt, do_sample=False)
        # generated_text echoes the prompt; keep only what follows the
        # final assistant header.
        generated = outputs[0]['generated_text'].split("assistant<|end_header_id|>")[-1].strip()
        # Extract just the outermost JSON object from the generation.
        start = generated.find('{')
        end = generated.rfind('}') + 1
        if start == -1 or end <= start:
            return fallback
        return json.loads(generated[start:end])
    # Targeted instead of bare `except:` (which also swallowed
    # SystemExit/KeyboardInterrupt). JSONDecodeError is a ValueError subclass.
    except (KeyError, IndexError, TypeError, ValueError):
        return fallback
def main():
    """Streamlit entry point: LLM query parsing -> filtered vector search
    (ChromaDB) -> 3D embedding-space visualization of query and results."""
    st.title("ðŸ¦ Nyang Smart Retriever (Thinking Process)")
    # Load the precomputed visualization bundle (dataframe + reducers + model paths).
    viz_data = load_viz_data()
    if not viz_data:
        st.error("ì‹œê°í™” ë°ì´í„°ê°€ ì—†ìŠµë‹ˆë‹¤.")
        return
    df = viz_data['dataframe']
    model_name = "V2"  # pinned to the latest embedding model
    reducer = viz_data['reducers'][model_name]  # projector with .transform() -> 3D coords
    model_path = viz_data['models'][model_name]
    x_col, y_col, z_col = f'x_{model_name}', f'y_{model_name}', f'z_{model_name}'
    # Load models. NOTE(review): the original comment claimed this runs only
    # once, but nothing here caches it — this line re-executes on every
    # Streamlit rerun unless load_models itself is cached.
    with st.spinner("AI ëª¨ë¸ ë¡œë”© ì¤‘... (VRAM í™•ë³´)"):
        embeddings, llm_pipe = load_models(model_path)
    # --- UI ---
    col1, col2 = st.columns([1, 2])
    with col1:
        st.header("1. ì§ˆë¬¸ ì…ë ¥")
        query = st.text_input("ì§ˆë¬¸", "3ë§Œì›ëŒ€ ë¡œì–„ìºë‹Œ ê³ ì–‘ì´ ì‚¬ë£Œ ë³´ì—¬ì¤˜")
        if query:
            st.header("2. AI ë¶„ì„ (Thinking)")
            # Run the LLM parser and surface its structured output.
            search_filter = parse_query_with_llm(llm_pipe, query)
            st.json(search_filter)
            # Build the metadata filter.
            where_clause = {}
            if search_filter.get('brand'):
                # Metadata filters here have no partial matching, so data
                # normalization matters; exact-equality match for the demo.
                where_clause['brand'] = search_filter['brand']
            # Connect to ChromaDB.
            client = chromadb.PersistentClient(path=VECTOR_DB_DIR)
            collection = client.get_collection("product_search")
            # Embed the (possibly LLM-rewritten) query text; falls back to the
            # raw question only if the 'query' key is absent (not if it is None).
            q_vec = embeddings.embed_query(search_filter.get('query', query))
            # Run the filtered vector search.
            try:
                results = collection.query(
                    query_embeddings=[q_vec],
                    n_results=5,
                    where=where_clause if where_clause else None
                    # Price-range filtering via compound where clauses is
                    # awkward in ChromaDB; post-filtering may be easier.
                )
                st.header("3. ê²€ìƒ‰ ê²°ê³¼ (Filtered)")
                # NOTE(review): Chroma returns nested lists; results['ids'] may
                # be truthy (e.g. [[]]) even with zero hits — verify this branch.
                if results['ids']:
                    res_df = pd.DataFrame({
                        'ìƒí’ˆëª…': [m['product_name'] for m in results['metadatas'][0]],
                        'ë¸Œëžœë“œ': [m['brand'] for m in results['metadatas'][0]],
                        'ê°€ê²©': [m['price'] for m in results['metadatas'][0]],
                        'ê±°ë¦¬': results['distances'][0]
                    })
                    st.table(res_df)
                else:
                    st.warning("ì¡°ê±´ì— ë§žëŠ” ìƒí’ˆì´ ì—†ìŠµë‹ˆë‹¤.")
            except Exception as e:
                st.error(f"ê²€ìƒ‰ ì˜¤ë¥˜: {e}")
    with col2:
        st.header("4. ì‹œê°í™” (Embedding Space)")
        fig = px.scatter_3d(
            df, x=x_col, y=y_col, z=z_col,
            color='category',
            hover_data=['product_name', 'price', 'brand'],
            opacity=0.3, height=800,
            title="Search Debugger View"
        )
        fig.update_traces(marker=dict(size=3))
        if query:
            # Project the query vector into the same 3D space as the products.
            # q_vec leaks in from the col1 branch above; it is assigned
            # whenever query is truthy, before the try block.
            q_proj = reducer.transform(np.array(q_vec).reshape(1, -1))
            fig.add_trace(go.Scatter3d(
                x=[q_proj[0, 0]], y=[q_proj[0, 1]], z=[q_proj[0, 2]],
                mode='markers+text',
                marker=dict(size=15, color='red', symbol='diamond'),
                name='Query Intent'
            ))
            # Highlight the retrieved results; 'results' only exists if the
            # search above succeeded, hence the locals() guard.
            if 'results' in locals() and results['ids']:
                found_ids = results['ids'][0]  # hit ID list (currently unused)
                # IDs are not mapped back into df, which makes exact
                # highlighting hard (prepare_data would need to store IDs);
                # matching by product name instead — inaccurate on duplicates.
                found_names = [m['product_name'] for m in results['metadatas'][0]]
                found_df = df[df['product_name'].isin(found_names)]
                fig.add_trace(go.Scatter3d(
                    x=found_df[x_col], y=found_df[y_col], z=found_df[z_col],
                    mode='markers',
                    marker=dict(size=8, color='yellow', line=dict(width=2, color='black')),
                    name='Filtered Results'
                ))
        st.plotly_chart(fig, use_container_width=True)

if __name__ == "__main__":
    main()