# ChaTech / chatbot.py
# m00k10m's picture
# Upload 5 files
# 52a84f2 verified
import os
import csv
import time
import torch
import argparse
import chromadb
import datetime
import gradio as gr
from groq import Groq
from pathlib import Path
from prompt_db import *
from chromadb.utils import embedding_functions
def get_chroma_collection(db_path: str, collection_name: str, *, embedf_name: str = "") -> chromadb.Collection | None:
    """Open a persistent ChromaDB store and load an existing collection from it.

    Args:
        db_path: Path of the directory that holds the persisted ChromaDB data.
        collection_name: Name of the collection to load.
        embedf_name: Optional SentenceTransformer model name. When empty,
            Chroma's default embedding function is used instead.

    Returns:
        The loaded collection, or None when ``db_path`` does not exist or the
        collection cannot be loaded.
    """
    # Bail out early when there is nothing on disk to open.
    if not os.path.exists(db_path):
        print(f"collection {collection_name} ์„(๋ฅผ) ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๊ฒฝ๋กœ๋ฅผ ๋‹ค์‹œ ํ™•์ธํ•ด์ฃผ์„ธ์š”.")
        return None

    persistent_client = chromadb.PersistentClient(path=db_path)

    # Pick the embedding function: default unless a model name was supplied.
    if not embedf_name:
        embedder = embedding_functions.DefaultEmbeddingFunction()
        print("์ž„๋ฒ ๋”ฉ ํ•จ์ˆ˜๋กœ ๊ธฐ๋ณธ ์ž„๋ฒ ๋”ฉ ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ")
    else:
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=embedf_name,
            device=target_device,
        )
        print(f"์ž„๋ฒ ๋”ฉ ํ•จ์ˆ˜๋กœ {embedf_name} ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ")

    # Load the already-existing collection; any failure is reported and mapped to None.
    try:
        loaded = persistent_client.get_collection(
            name=collection_name,
            embedding_function=embedder,
        )
        print(f"Collection '{collection_name}' ์„(๋ฅผ) ์„ฑ๊ณต์ ์œผ๋กœ ๋ถˆ๋Ÿฌ์™”์Šต๋‹ˆ๋‹ค. ")
        return loaded
    except Exception as e:
        print(f"Collection '{collection_name}' ์„(๋ฅผ) ๋ถˆ๋Ÿฌ์˜ค์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค : {e}")
        return None
def query_db(collection: chromadb.Collection,
             query_text: str,
             n_results: int) -> str:
    """Search the collection for documents related to the question and format them.

    Args:
        collection: Collection to search; ``None`` means no database is connected.
        query_text: The user's question, used as the query text.
        n_results: Maximum number of documents to retrieve.

    Returns:
        One string with every retrieved document (numbered, with its title and
        date metadata) separated by blank lines, or ``""`` when the collection
        is missing, nothing was found, or the query failed.
    """
    if collection is None:
        print("๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค๊ฐ€ ์—ฐ๊ฒฐ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
        return ""

    try:
        results = collection.query(
            query_texts=[query_text],
            n_results=n_results,
        )

        # Nothing retrieved for the (single) query text.
        if not results["documents"] or not results["documents"][0]:
            print("๊ด€๋ จ๋œ ๋ฌธ์„œ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
            return ""

        documents = results["documents"][0]

        # Metadata is optional: the whole field, the per-query list, or a single
        # entry may be missing/None. Treat those as "no metadata" so that one
        # bare document does not discard the entire retrieved context.
        metadata_batches = results.get("metadatas") or [[]]
        metadatas = list(metadata_batches[0] or [])

        context_parts = []
        for i, doc in enumerate(documents):
            meta = (metadatas[i] if i < len(metadatas) else None) or {}
            source = meta.get("title", "์ œ๋ชฉ ์—†์Œ")
            date = meta.get("date", "๋‚ ์งœ ์—†์Œ")
            context_parts.append(f"๋ฌธ์„œ{i+1} [์ œ๋ชฉ: {source}, ๋‚ ์งœ: {date}]\n๋‚ด์šฉ : {doc}")

        # Join all documents into a single context string.
        return "\n\n".join(context_parts)

    except Exception as e:
        # Best-effort retrieval: report the failure and fall back to an empty context.
        print(f"๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
        return ""
def save_log(base_dir, log_dir, request, user_message, assistant_message):
    """Append one conversation turn to the per-day CSV chat log.

    The log lives at ``<base_dir>/<log_dir>/chat_log_<yymmdd>.csv``. The
    directory (including missing parents) and the file with its header row are
    created on demand. Logging is best-effort: any failure is printed and
    swallowed so it can never break the chat response itself.

    Args:
        base_dir: Base directory of the application.
        log_dir: Log directory, relative to ``base_dir``.
        request: Object exposing ``client.host`` (e.g. a Gradio request), or
            ``None``. When no host can be read, ``"Unknown_IP"`` is logged.
        user_message: Text sent by the user.
        assistant_message: Text answered by the assistant.
    """
    log_path = os.path.join(base_dir, log_dir)

    # Take a single timestamp so the file name and the row can never disagree
    # (two separate now() calls could straddle midnight).
    now = datetime.datetime.now()
    today = now.strftime("%y%m%d")
    dest_file_path = os.path.join(log_path, f"chat_log_{today}.csv")

    # request or request.client may be None; fall back instead of raising.
    client_info = getattr(request, "client", None)
    user_ip = getattr(client_info, "host", None) or "Unknown_IP"
    user_conv_log = [user_ip, now.strftime("%Y-%m-%d %H:%M:%S"), user_message, assistant_message]

    try:
        if not os.path.exists(log_path):
            # makedirs also creates missing parents and tolerates a concurrent creation.
            os.makedirs(log_path, exist_ok=True)
            print(f"{log_dir} ํด๋”๊ฐ€ ์ƒ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค : {log_path}")

        # A new (or empty) daily file gets the header row first.
        needs_header = (not os.path.exists(dest_file_path)
                        or os.path.getsize(dest_file_path) == 0)

        with open(dest_file_path, mode="a", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            if needs_header:
                writer.writerow(["user_ip", "time_stamp", "user_message", "assistant_message"])
            writer.writerow(user_conv_log)
    except Exception as e:
        print(f"๋Œ€ํ™” ๋กœ๊ทธ ์ €์žฅ ์‹คํŒจ : {e}")
def get_response(user_message: str,
                 system_prompt: str,
                 collection: chromadb.Collection,
                 history: list[dict | list | tuple],
                 request: gr.Request,
                 client: Groq,
                 base_dir: str,
                 log_dir: str,
                 model_name: str,
                 n_results: int,
                 temperature: float):
    """Stream an answer to ``user_message`` using retrieved context, then log the turn.

    This is a generator: each yielded value is the assistant text accumulated
    so far, so the caller can render the answer as it is being produced.

    Args:
        user_message: The user's latest message. The exit keyword ends the
            conversation without calling the model or writing a log entry.
        system_prompt: Template whose ``{context}`` placeholder receives the
            retrieved documents.
        collection: Collection searched for context (passed to ``query_db``).
        history: Previous turns, either ``{"role", "content"}`` dicts or
            legacy two-item ``[user, assistant]`` pairs.
        request: Request object forwarded to ``save_log``.
        client: Chat-completions client used to generate the answer.
        base_dir: Base directory forwarded to ``save_log``.
        log_dir: Log directory forwarded to ``save_log``.
        model_name: Model identifier sent to the API.
        n_results: Number of documents to retrieve as context.
        temperature: Sampling temperature sent to the API.

    Yields:
        The progressively growing assistant message, or a single error message
        when generation fails.
    """
    if user_message.strip() == "๋๋":
        end_message = "๋Œ€ํ™”๋ฅผ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค. ์ƒˆ ๋Œ€ํ™”๋ฅผ ์‹œ์ž‘ํ•˜๋ ค๋ฉด ์˜ค๋ฅธ์ชฝ ์ƒ๋‹จ์˜ Clear ๋ฒ„ํŠผ(ํœด์ง€ํ†ต ์•„์ด์ฝ˜)์„ ํด๋ฆญํ•ด์ฃผ์„ธ์š”."
        yield end_message
        return

    # RAG: retrieve the context related to the user's question.
    context = query_db(collection=collection,
                       query_text=user_message,
                       n_results=n_results)

    # Inject the retrieved context into the system prompt.
    formatted_system_prompt = system_prompt.format(context=context)

    # Build the message list: system prompt, prior turns, then the new question.
    messages = [{"role": "system", "content": formatted_system_prompt}]
    for chat in history:
        if isinstance(chat, dict):
            messages.append({"role": chat["role"], "content": chat["content"]})
        # Legacy pair format used by older Gradio versions. Accept tuples as
        # well as lists, and skip a side that is still None (e.g. an
        # unanswered turn) rather than sending None as message content.
        elif isinstance(chat, (list, tuple)) and len(chat) == 2:
            user_turn, assistant_turn = chat
            if user_turn is not None:
                messages.append({"role": "user", "content": user_turn})
            if assistant_turn is not None:
                messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": user_message})

    # Ask the LLM for an answer and stream it back piece by piece.
    assistant_message = ""
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            stream=True,
        )

        for chunk in response:
            # NOTE(review): a streamed chunk may carry an empty `choices` list
            # (OpenAI-compatible APIs do this for usage-only chunks); skip it
            # instead of aborting the whole answer with an IndexError.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                assistant_message += delta
                yield assistant_message

    except Exception as e:
        error_message = f"๋‹ต๋ณ€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค. : {str(e)}"
        yield error_message
        # The error text is what the user saw, so that is what gets logged.
        assistant_message = error_message

    save_log(base_dir, log_dir, request, user_message, assistant_message)
def chat_with_rag(api_key: str,
                  collection: chromadb.Collection,
                  system_prompt: str,
                  args: argparse.Namespace) -> None:
    """Create the Groq client and launch the Gradio chat interface.

    Args:
        api_key: Groq API key.
        collection: Collection used to retrieve context for each question.
        system_prompt: System prompt template passed to ``get_response``.
        args: Parsed command-line options; ``base_dir``, ``log_dir``,
            ``model_name``, ``n_results`` and ``temperature`` are read.

    Returns:
        None. When the Groq client cannot be created the error is printed and
        the function returns without starting the UI.
    """
    try:
        groq_client = Groq(api_key=api_key)
    except Exception as e:
        print(f"Groq client๋ฅผ ๋ถˆ๋Ÿฌ์˜ค์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. API Key๋ฅผ ํ™•์ธํ•ด์ฃผ์„ธ์š” : {e}")
        # Without a client every chat request would crash on the unbound
        # `groq_client`, so do not launch the interface at all.
        return

    def predict(user_message, history, request: gr.Request):
        # Thin adapter: Gradio supplies message/history/request, the rest is
        # bound from the enclosing scope.
        yield from get_response(
            user_message=user_message,
            system_prompt=system_prompt,
            collection=collection,
            history=history,
            request=request,
            client=groq_client,
            base_dir=args.base_dir,
            log_dir=args.log_dir,
            model_name=args.model_name,
            n_results=args.n_results,
            temperature=args.temperature,
        )

    title = "ChaTech"
    # User-facing description shown above the chat box (kept verbatim).
    description = """
์„œ์šธ๊ณผํ•™๊ธฐ์ˆ ๋Œ€ํ•™๊ต ๊ณต์ง€์‚ฌํ•ญ ๊ธฐ๋ฐ˜ ์งˆ์˜์‘๋‹ต ์ฑ—๋ด‡์ž…๋‹ˆ๋‹ค.
๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์— ์ €์žฅ๋œ ๊ณต์ง€์‚ฌํ•ญ ๋‚ด์šฉ์„ ๋ฐ”ํƒ•์œผ๋กœ ๋‹ต๋ณ€ํ•ฉ๋‹ˆ๋‹ค.
๋Œ€ํ™” ์ข…๋ฃŒ๋ฅผ ์›ํ•˜์‹ค ๊ฒฝ์šฐ ์ฑ„ํŒ…์ฐฝ์— \'๋๋\'์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.
"""

    demo = gr.ChatInterface(
        fn=predict,
        title=title,
        description=description,
    ).queue()

    demo.launch(debug=True, share=True)
def get_system_prompt(prompt_type: str) -> str:
    """Return the system prompt text for the requested prompt type.

    Args:
        prompt_type: Which prompt to use.
            ``"v"``    - vanilla prompt (``Vanilla().get_prompt()``).
            ``"adv1"`` - advanced prompt ver.1, not implemented yet; an empty
                         string is returned.
            Any other value prints a warning and falls back to the vanilla prompt.

    Returns:
        The full system prompt text.
    """
    # Improved prompt version: not implemented yet.
    if prompt_type == "adv1":
        return ""

    if prompt_type != "v":
        print("์œ ํšจํ•˜์ง€ ์•Š์€ ํ”„๋กฌํ”„ํŠธ ํƒ€์ž…์ž…๋‹ˆ๋‹ค. ๊ธฐ๋ณธ๊ฐ’(Vanilla)์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ")

    # Build the prompt object here so the fallback path has one too; the
    # previous fallback referenced a variable that was never assigned.
    return Vanilla().get_prompt()
def main(args):
    """Load the collection and the system prompt, then start the chatbot.

    Args:
        args: Parsed command-line options; ``base_dir``, ``db_dir``,
            ``collection_name``, ``prompt_type`` and ``api_key`` are read here
            and the whole namespace is forwarded to ``chat_with_rag``.
    """
    # Location of the persisted ChromaDB data.
    db_location = os.path.join(args.base_dir, args.db_dir)

    # Load the collection with the default embedding function. To use another
    # embedding model, additionally pass embedf_name=args.embedf_name.
    collection = get_chroma_collection(db_location, args.collection_name)

    if collection is not None:
        # Fetch the system prompt, then run the chatbot.
        prompt_text = get_system_prompt(args.prompt_type)
        chat_with_rag(
            api_key=args.api_key,
            collection=collection,
            system_prompt=prompt_text,
            args=args,
        )
        return

    print("Chromadb Collection์„ ๋ถˆ๋Ÿฌ์˜ค์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ํ”„๋กœ๊ทธ๋žจ์„ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค. ")
if __name__ == "__main__":
    # Command-line options as (flag, type, default), registered in this order.
    # Other model names that have been used with --model_name:
    #   llama-3.1-8b-instant, llama-3.3-70b-versatile, openai/gpt-oss-120b
    cli_options = (
        ("--api_key", str, ""),
        ("--base_dir", str, str(Path(__file__).resolve().parent)),
        ("--db_dir", str, "seoultech_data_db"),
        ("--log_dir", str, "chat_log"),
        ("--model_name", str, "llama-3.3-70b-versatile"),
        ("--temperature", float, 0.5),
        ("--n_results", int, 3),
        ("--collection_name", str, "seoultech_notices"),
        ("--embedf_name", str, "BAAI/bge-m3"),
        ("--prompt_type", str, "v"),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    main(args)