# Deployed by leedami via Team Script (commit 41cc6f7, verified).
import streamlit as st
import pandas as pd
import numpy as np
import pickle
import plotly.express as px
import plotly.graph_objects as go
import os
import json
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import chromadb
# Path setup: add the project root to sys.path so 'scripts.utils' imports resolve.
import sys
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(BASE_DIR)
from scripts.utils.config import EMBEDDING_PKL_PATH, VECTOR_DB_DIR, MODEL_LLM_LATEST
# Page config must be the first Streamlit call in the script.
st.set_page_config(layout="wide", page_title="Nyang Smart Retriever Debugger")
# --- Resource Loading ---
@st.cache_resource
def load_viz_data():
    """Load the pre-computed embedding visualization bundle from disk.

    Returns:
        The unpickled dict (dataframe, reducers, model paths), or None when
        the pickle file does not exist.
    """
    if os.path.exists(EMBEDDING_PKL_PATH):
        with open(EMBEDDING_PKL_PATH, 'rb') as f:
            return pickle.load(f)
    return None
@st.cache_resource
def load_models(embedding_model_path):
    """Load the sentence-embedding model and the 4-bit quantized parser LLM.

    Cached by Streamlit so heavy model initialization happens once per session.

    Args:
        embedding_model_path: path/name of the HuggingFace embedding model.

    Returns:
        (embeddings, llm_pipe): the embedder and a text-generation pipeline.
    """
    # Sentence embedder on GPU; vectors are normalized for cosine-style distance.
    embedder = HuggingFaceEmbeddings(
        model_name=embedding_model_path,
        model_kwargs={'device': 'cuda'},
        encode_kwargs={'normalize_embeddings': True},
    )

    # Query-parsing LLM quantized to 4-bit NF4 (double quant) to save VRAM.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    tok = AutoTokenizer.from_pretrained(MODEL_LLM_LATEST)
    lm = AutoModelForCausalLM.from_pretrained(
        MODEL_LLM_LATEST, quantization_config=quant_cfg, device_map="auto"
    )
    generator = pipeline("text-generation", model=lm, tokenizer=tok, max_new_tokens=128)
    return embedder, generator
# --- Core Logic ---
def parse_query_with_llm(pipe, query):
    """Parse a natural-language question into structured search conditions.

    Prompts the LLM to emit a JSON object with brand/category/price filters
    plus a distilled search query. On any failure (pipeline error, malformed
    or missing JSON) it degrades gracefully to an unfiltered query so the
    app keeps working.

    Args:
        pipe: text-generation callable invoked as ``pipe(prompt, do_sample=False)``
            returning ``[{'generated_text': ...}]``.
        query: raw user question.

    Returns:
        dict with keys ``brand``, ``category``, ``price_min``, ``price_max``, ``query``.
    """
    prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
๋„ˆ๋Š” ๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ ๋ถ„์„๊ธฐ๋‹ค. ์‚ฌ์šฉ์ž ์งˆ๋ฌธ์—์„œ ๋ธŒ๋žœ๋“œ, ์นดํ…Œ๊ณ ๋ฆฌ, ๊ฐ€๊ฒฉ ์กฐ๊ฑด์„ ์ถ”์ถœํ•˜์—ฌ JSON์œผ๋กœ ์ถœ๋ ฅํ•˜๋ผ.
๊ฐ€๊ฒฉ์€ ์ˆซ์ž๋งŒ, ์—†๋Š” ์กฐ๊ฑด์€ null๋กœ ํ‘œ๊ธฐํ•ด๋ผ.
์˜ˆ์‹œ: "3๋งŒ์›๋Œ€ ๋กœ์–„์บ๋‹Œ ์‚ฌ๋ฃŒ"
์ถœ๋ ฅ: {{"brand": "๋กœ์–„์บ๋‹Œ", "category": "์‚ฌ๋ฃŒ", "price_min": 30000, "price_max": 40000, "query": "์‚ฌ๋ฃŒ ์ถ”์ฒœ"}}
<|eot_id|><|start_header_id|>user<|end_header_id|>
"{query}"
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
    try:
        outputs = pipe(prompt, do_sample=False)
        # Keep only the text generated after the (last) assistant header.
        generated = outputs[0]['generated_text'].split("assistant<|end_header_id|>")[-1].strip()
        # Extract just the {...} span in case the model wrapped the JSON in prose.
        start = generated.find('{')
        end = generated.rfind('}') + 1
        return json.loads(generated[start:end])
    except Exception:
        # BUGFIX: was a bare `except:`, which also swallows SystemExit and
        # KeyboardInterrupt. Any parse/pipeline failure now falls back to an
        # unfiltered search instead of crashing the UI.
        return {"brand": None, "category": None, "price_min": None, "price_max": None, "query": query}
def main():
    """Streamlit debugger UI.

    Parses a natural-language query with an LLM, runs a filtered vector
    search against ChromaDB, and visualizes the embedding space with the
    query intent and retrieved hits highlighted.
    """
    st.title("๐Ÿฆ Nyang Smart Retriever (Thinking Process)")

    # Load the pre-computed visualization bundle (dataframe, reducers, model paths).
    viz_data = load_viz_data()
    if not viz_data:
        st.error("์‹œ๊ฐํ™” ๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.")
        return
    df = viz_data['dataframe']
    model_name = "V2"  # pinned to the latest embedding model
    reducer = viz_data['reducers'][model_name]
    model_path = viz_data['models'][model_name]
    x_col, y_col, z_col = f'x_{model_name}', f'y_{model_name}', f'z_{model_name}'

    # Model loading (cached by @st.cache_resource — runs only once).
    with st.spinner("AI ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘... (VRAM ํ™•๋ณด)"):
        embeddings, llm_pipe = load_models(model_path)

    # --- UI ---
    col1, col2 = st.columns([1, 2])
    with col1:
        st.header("1. ์งˆ๋ฌธ ์ž…๋ ฅ")
        query = st.text_input("์งˆ๋ฌธ", "3๋งŒ์›๋Œ€ ๋กœ์–„์บ๋‹Œ ๊ณ ์–‘์ด ์‚ฌ๋ฃŒ ๋ณด์—ฌ์ค˜")
        if query:
            st.header("2. AI ๋ถ„์„ (Thinking)")
            # LLM turns the free-form question into structured filters.
            search_filter = parse_query_with_llm(llm_pipe, query)
            st.json(search_filter)

            # Build the metadata filter. Chroma's `where` has no partial-match
            # operator, so exact brand matching relies on clean source data.
            where_clause = {}
            if search_filter.get('brand'):
                where_clause['brand'] = search_filter['brand']

            # Connect to the persistent vector store.
            client = chromadb.PersistentClient(path=VECTOR_DB_DIR)
            collection = client.get_collection("product_search")

            # Embed the distilled query, falling back to the raw question.
            q_vec = embeddings.embed_query(search_filter.get('query', query))

            # Filtered similarity search. Price-range filtering is deliberately
            # left to post-processing: Chroma's compound where clauses are awkward.
            try:
                results = collection.query(
                    query_embeddings=[q_vec],
                    n_results=5,
                    where=where_clause if where_clause else None
                )
                st.header("3. ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ (Filtered)")
                # BUGFIX: results['ids'] is a list of lists ([[...]]), which is
                # truthy even with zero hits ([[]]) — check the inner list so
                # the "no results" warning can actually fire.
                if results['ids'] and results['ids'][0]:
                    res_df = pd.DataFrame({
                        '์ƒํ’ˆ๋ช…': [m['product_name'] for m in results['metadatas'][0]],
                        '๋ธŒ๋žœ๋“œ': [m['brand'] for m in results['metadatas'][0]],
                        '๊ฐ€๊ฒฉ': [m['price'] for m in results['metadatas'][0]],
                        '๊ฑฐ๋ฆฌ': results['distances'][0]
                    })
                    st.table(res_df)
                else:
                    st.warning("์กฐ๊ฑด์— ๋งž๋Š” ์ƒํ’ˆ์ด ์—†์Šต๋‹ˆ๋‹ค.")
            except Exception as e:
                st.error(f"๊ฒ€์ƒ‰ ์˜ค๋ฅ˜: {e}")

    with col2:
        st.header("4. ์‹œ๊ฐํ™” (Embedding Space)")
        fig = px.scatter_3d(
            df, x=x_col, y=y_col, z=z_col,
            color='category',
            hover_data=['product_name', 'price', 'brand'],
            opacity=0.3, height=800,
            title="Search Debugger View"
        )
        fig.update_traces(marker=dict(size=3))

        if query:
            # Project the query vector into the same reduced 3D space.
            q_proj = reducer.transform(np.array(q_vec).reshape(1, -1))
            fig.add_trace(go.Scatter3d(
                x=[q_proj[0, 0]], y=[q_proj[0, 1]], z=[q_proj[0, 2]],
                mode='markers+text',
                marker=dict(size=15, color='red', symbol='diamond'),
                name='Query Intent'
            ))

            # Highlight the retrieved hits. IDs are not mapped back into the
            # dataframe (prepare_data would need to persist them), so we match
            # by product name — may over-highlight on duplicate names.
            if 'results' in locals() and results['ids'] and results['ids'][0]:
                found_names = [m['product_name'] for m in results['metadatas'][0]]
                found_df = df[df['product_name'].isin(found_names)]
                fig.add_trace(go.Scatter3d(
                    x=found_df[x_col], y=found_df[y_col], z=found_df[z_col],
                    mode='markers',
                    marker=dict(size=8, color='yellow', line=dict(width=2, color='black')),
                    name='Filtered Results'
                ))
        st.plotly_chart(fig, use_container_width=True)
# Entry point when executed directly (e.g. `streamlit run <this file>`).
if __name__ == "__main__":
    main()