|
|
import os
|
|
|
import csv
|
|
|
import time
|
|
|
import torch
|
|
|
import argparse
|
|
|
import chromadb
|
|
|
import datetime
|
|
|
import gradio as gr
|
|
|
|
|
|
from groq import Groq
|
|
|
from pathlib import Path
|
|
|
from prompt_db import *
|
|
|
from chromadb.utils import embedding_functions
|
|
|
|
|
|
|
|
|
def get_chroma_collection(db_path: str, collection_name: str, *, embedf_name: str = "") -> chromadb.Collection | None:
    """Open a ChromaDB collection stored in a persistent database directory.

    Args:
        db_path: Absolute path of the directory holding the ChromaDB data.
        collection_name: Name of the collection to open.
        embedf_name: Optional SentenceTransformer model name to use as the
            embedding function; when empty, Chroma's default embedding
            function is used.

    Returns:
        The collection object, or ``None`` when the path does not exist or
        the collection cannot be opened.
    """
    if not os.path.exists(db_path):
        print(f"collection {collection_name} ์(๋ฅผ) ์ฐพ์ ์ ์์ต๋๋ค. ๊ฒฝ๋ก๋ฅผ ๋ค์ ํ์ธํด์ฃผ์ธ์.")
        return None

    client = chromadb.PersistentClient(path=db_path)

    # Pick the embedding function: a named SentenceTransformer model when
    # requested, otherwise Chroma's built-in default.
    if not embedf_name:
        embed_fn = embedding_functions.DefaultEmbeddingFunction()
        print("์๋ฒ ๋ฉ ํจ์๋ก ๊ธฐ๋ณธ ์๋ฒ ๋ฉ ํจ์๋ฅผ ์ฌ์ฉํฉ๋๋ค. ")
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        embed_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=embedf_name,
            device=device,
        )
        print(f"์๋ฒ ๋ฉ ํจ์๋ก {embedf_name} ๋ฅผ ์ฌ์ฉํฉ๋๋ค. ")

    try:
        collection = client.get_collection(
            name=collection_name,
            embedding_function=embed_fn,
        )
    except Exception as e:
        print(f"Collection '{collection_name}' ์(๋ฅผ) ๋ถ๋ฌ์ค์ง ๋ชปํ์ต๋๋ค : {e}")
        return None

    print(f"Collection '{collection_name}' ์(๋ฅผ) ์ฑ๊ณต์ ์ผ๋ก ๋ถ๋ฌ์์ต๋๋ค. ")
    return collection
|
|
|
|
|
|
|
|
|
def query_db(collection: chromadb.Collection,
             query_text: str,
             n_results: int) -> str:
    """Retrieve the documents most relevant to a user question from the DB.

    Args:
        collection: Opened ChromaDB collection to search (may be None).
        query_text: User question used as the semantic query.
        n_results: Maximum number of documents to retrieve.

    Returns:
        The matched documents formatted as one context string, or an empty
        string when the DB is unavailable, nothing matches, or an error
        occurs.
    """
    if collection is None:
        print("๋ฐ์ดํฐ๋ฒ ์ด์ค๊ฐ ์ฐ๊ฒฐ๋์ง ์์์ต๋๋ค.")
        return ""

    try:
        results = collection.query(
            query_texts = [query_text],
            n_results = n_results
        )

        if not results["documents"] or not results["documents"][0]:
            print("๊ด๋ จ๋ ๋ฌธ์๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค.")
            return ""

        documents = results["documents"][0]
        # Robustness fix: documents stored without metadata come back as
        # None entries (or "metadatas" may be absent when excluded), and
        # the old `metadatas[i].get(...)` raised AttributeError/IndexError.
        # Normalize to per-document dicts so formatting never crashes.
        metadatas = (results.get("metadatas") or [[]])[0] or []

        context_parts = []
        for i, doc in enumerate(documents):
            meta = metadatas[i] if i < len(metadatas) and metadatas[i] else {}
            source = meta.get("title", "์ ๋ชฉ ์์")
            date = meta.get("date", "๋ ์ง ์์")
            context_parts.append(f"๋ฌธ์{i+1} [์ ๋ชฉ: {source}, ๋ ์ง: {date}]\n๋ด์ฉ : {doc}")

        data = "\n\n".join(context_parts)
        return data

    except Exception as e:
        print(f"๊ฒ์ ์ค ์ค๋ฅ ๋ฐ์: {e}")
        return ""
|
|
|
|
|
|
|
|
|
def save_log(base_dir, log_dir, request, user_message, assistant_message):
    """Append one chat exchange to today's CSV log file.

    Args:
        base_dir: Directory under which the log folder lives.
        log_dir: Name of the log folder (created if missing).
        request: Gradio request object used for the client IP, or None
            (logged as "Unknown_IP").
        user_message: The user's message.
        assistant_message: The assistant's reply.
    """
    log_path = os.path.join(base_dir, log_dir)

    # Fix: os.makedirs(exist_ok=True) creates missing parent directories
    # and is race-free, unlike the previous check-then-os.mkdir pattern
    # which crashed if the folder appeared concurrently or parents were
    # missing.
    if not os.path.exists(log_path):
        os.makedirs(log_path, exist_ok=True)
        print(f"{log_dir} ํด๋๊ฐ ์์ฑ๋์์ต๋๋ค : {log_path}")

    today = datetime.datetime.now().strftime("%y%m%d")
    file_name = f"chat_log_{today}.csv"
    dest_file_path = os.path.join(log_path, file_name)

    # Write the CSV header once, when the day's file is first created.
    if not os.path.exists(dest_file_path):
        with open(dest_file_path, mode = "w", newline = "", encoding = "utf-8") as file:
            writer = csv.writer(file)
            writer.writerow(["user_ip", "time_stamp", "user_message", "assistant_message"])

    user_ip = request.client.host if request else "Unknown_IP"
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    user_conv_log = [user_ip, timestamp, user_message, assistant_message]

    try:
        with open(dest_file_path, mode = "a", newline = "", encoding = "utf-8") as file:
            writer = csv.writer(file)
            writer.writerow(user_conv_log)
    except Exception as e:
        print(f"๋ํ ๋ก๊ทธ ์ ์ฅ ์คํจ : {e}")
|
|
|
|
|
|
def get_response(user_message: str,
                 system_prompt: str,
                 collection: chromadb.Collection,
                 history: list[dict | list],
                 request: gr.Request,
                 client: Groq,
                 base_dir: str,
                 log_dir: str,
                 model_name: str,
                 n_results: int,
                 temperature: float):
    """Stream an LLM answer grounded in documents retrieved from the DB.

    Yields the partially accumulated assistant message so the UI renders
    the answer incrementally, then appends the finished exchange to the
    chat log.
    """
    # Sentinel message that ends the conversation without calling the model.
    if user_message.strip() == "๋๋":
        yield "๋ํ๋ฅผ ์ข๋ฃํฉ๋๋ค. ์ ๋ํ๋ฅผ ์์ํ๋ ค๋ฉด ์ค๋ฅธ์ชฝ ์๋จ์ Clear ๋ฒํผ(ํด์งํต ์์ด์ฝ)์ ํด๋ฆญํด์ฃผ์ธ์."
        return

    # Retrieve related documents and substitute them into the prompt.
    retrieved = query_db(collection = collection,
                         query_text = user_message,
                         n_results = n_results)

    # Build the OpenAI-style message list: system prompt, prior turns
    # (either message-dict or [user, assistant] pair form), new question.
    messages = [{"role": "system", "content": system_prompt.format(context=retrieved)}]
    for turn in history:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
        elif isinstance(turn, list) and len(turn) == 2:
            user_part, assistant_part = turn
            messages.append({"role": "user", "content": user_part})
            messages.append({"role": "assistant", "content": assistant_part})
    messages.append({"role": "user", "content": user_message})

    assistant_message = ""
    try:
        stream = client.chat.completions.create(
            model = model_name,
            messages = messages,
            temperature = temperature,
            stream = True
        )
        # Accumulate streamed deltas, yielding the running answer each time.
        for chunk in stream:
            piece = chunk.choices[0].delta.content
            if piece:
                assistant_message += piece
                yield assistant_message
    except Exception as e:
        assistant_message = f"๋ต๋ณ ์์ฑ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค. : {str(e)}"
        yield assistant_message

    save_log(base_dir, log_dir, request, user_message, assistant_message)
|
|
|
|
|
|
|
|
|
def chat_with_rag(api_key: str,
                  collection: chromadb.Collection,
                  system_prompt: str,
                  args: argparse.Namespace) -> None:
    """Launch the Gradio RAG chatbot UI backed by a Groq model.

    Args:
        api_key: Groq API key.
        collection: ChromaDB collection used for retrieval.
        system_prompt: System prompt template containing a ``{context}`` slot.
        args: Parsed CLI arguments (model name, log paths, sampling options).

    Note:
        The ``args`` annotation was corrected from ``argparse.ArgumentParser``
        to ``argparse.Namespace`` — ``parse_args()`` returns a Namespace.
    """
    try:
        groq_client = Groq(api_key = api_key)
    except Exception as e:
        print(f"Groq client๋ฅผ ๋ถ๋ฌ์ค์ง ๋ชปํ์ต๋๋ค. API Key๋ฅผ ํ์ธํด์ฃผ์ธ์ : {e}")
        # BUG FIX: previously execution fell through here and later raised
        # NameError because `groq_client` was never assigned. Abort instead.
        return

    def predict(user_message, history, request: gr.Request):
        # Gradio callback: stream the RAG-grounded answer for one turn.
        yield from get_response(
            user_message = user_message,
            system_prompt = system_prompt,
            collection = collection,
            history = history,
            request = request,
            client = groq_client,
            base_dir = args.base_dir,
            log_dir = args.log_dir,
            model_name = args.model_name,
            n_results = args.n_results,
            temperature = args.temperature
        )

    title = "ChaTech"
    description = """
    ์์ธ๊ณผํ๊ธฐ์ ๋ํ๊ต ๊ณต์ง์ฌํญ ๊ธฐ๋ฐ ์ง์์๋ต ์ฑ๋ด์๋๋ค.
    ๋ฐ์ดํฐ๋ฒ ์ด์ค์ ์ ์ฅ๋ ๊ณต์ง์ฌํญ ๋ด์ฉ์ ๋ฐํ์ผ๋ก ๋ต๋ณํฉ๋๋ค.
    ๋ํ ์ข๋ฃ๋ฅผ ์ํ์ค ๊ฒฝ์ฐ ์ฑํ์ฐฝ์ \'๋๋\'์ ์๋ ฅํด์ฃผ์ธ์.
    """

    demo = gr.ChatInterface(
        fn = predict,
        title = title,
        description = description
    ).queue()

    demo.launch(debug = True, share = True)
|
|
|
|
|
|
|
|
|
|
|
|
def get_system_prompt(prompt_type: str) -> str:
    """Return the system prompt text selected by *prompt_type*.

    Prompts come from ``prompt_db`` (wildcard-imported at module top).

    Args:
        prompt_type: Which system prompt to use —
            "v": vanilla prompt,
            "adv1": advanced prompt ver.1 (not implemented yet).

    Returns:
        The full system prompt text (empty string for "adv1"); any other
        value falls back to the vanilla prompt.
    """
    if prompt_type == "v":
        return Vanilla().get_prompt()

    elif prompt_type == "adv1":
        # Advanced prompt is not implemented yet; callers get an empty prompt.
        return ""

    else:
        # BUG FIX: the original referenced an undefined local `vanilla`
        # here (it was only assigned in the "v" branch), raising NameError.
        # Build a fresh Vanilla prompt for the fallback instead.
        print("์ ํจํ์ง ์์ ํ๋กฌํํธ ํ์์๋๋ค. ๊ธฐ๋ณธ๊ฐ(Vanilla)์ ์ฌ์ฉํฉ๋๋ค. ")
        return Vanilla().get_prompt()
|
|
|
|
|
|
|
|
|
def main(args):
    """Wire together the DB collection, system prompt, and chatbot UI.

    Args:
        args: Parsed CLI arguments (see the ``__main__`` block).
    """
    abs_db_path = os.path.join(args.base_dir, args.db_dir)

    # BUG FIX: forward the CLI-selected embedding model. Previously
    # --embedf_name (default "BAAI/bge-m3") was never passed, so queries
    # silently used Chroma's default embedding function instead of the
    # model the collection was built with.
    collection = get_chroma_collection(abs_db_path,
                                       args.collection_name,
                                       embedf_name = args.embedf_name)

    if collection is None:
        print("Chromadb Collection์ ๋ถ๋ฌ์ค์ง ๋ชปํ์ต๋๋ค. ํ๋ก๊ทธ๋จ์ ์ข๋ฃํฉ๋๋ค. ")
        return

    system_prompt = get_system_prompt(args.prompt_type)

    chat_with_rag(api_key = args.api_key,
                  collection = collection,
                  system_prompt = system_prompt,
                  args = args)
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI options; defaults match the deployed configuration.
    parser = argparse.ArgumentParser()
    arg_specs = [
        ("--api_key", str, ""),
        ("--base_dir", str, str(Path(__file__).resolve().parent)),
        ("--db_dir", str, "seoultech_data_db"),
        ("--log_dir", str, "chat_log"),
        ("--model_name", str, "llama-3.3-70b-versatile"),
        ("--temperature", float, 0.5),
        ("--n_results", int, 3),
        ("--collection_name", str, "seoultech_notices"),
        ("--embedf_name", str, "BAAI/bge-m3"),
        ("--prompt_type", str, "v"),
    ]
    for flag, flag_type, default in arg_specs:
        parser.add_argument(flag, type = flag_type, default = default)

    main(parser.parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|