File size: 3,277 Bytes
3e9bb33 1388133 3e9bb33 52a84f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import os
import argparse
from pathlib import Path
from crawl import *
from chatbot import *
def main(args):
# 허깅페이스 secret key로부터 api key 읽어오기
api_key = args.api_key
if not api_key:
api_key = str(os.environ.get("GROQ_API_KEY", ""))
print(f"Groq API Key를 성공적으로 불러왔습니다. 뒷자리 4글자 : ...{api_key[-4:]}")
if not api_key:
print("API Key가 설정되지 않았습니다. Hugging Face Secrets에서 'GROQ_API_KEY'를 설정하거나 --api_key 인자를 입력해주세요.")
# 크롤러 파트
abs_download_path = os.path.join(args.base_dir, args.download_dir)
abs_db_path = os.path.join(args.base_dir, args.db_dir)
collection = make_db(abs_download_path, abs_db_path, args.collection_name)
# 기본 임베딩 함수 외의 함수를 이용할 경우
#collection = make_db(abs_download_path, abs_db_path, args.collection_name, embedf_name = args.embedf_name)
crawl_seoultech_notice(abs_download_path, args.base_url, args.num_page, collection)
# 챗봇 파트
collection = get_chroma_collection(abs_db_path, args.collection_name)
# embedding function로 다른 모델을 사용할 경우
# collection = get_chroma_collection(abs_db_path, args.collection_name, embedf_name = args.embedf_name)
if collection is None:
print("Chromadb Collection을 불러오지 못했습니다. 프로그램을 종료합니다. ")
return
# 시스템 프롬프트 불러오기
system_prompt = get_system_prompt(args.prompt_type)
# 챗봇 실행
chat_with_rag(api_key = api_key,
collection = collection,
system_prompt = system_prompt,
args = args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# 공통 인자
parser.add_argument("--base_dir", type = str, default = str(Path(__file__).resolve().parent)) # 현재 이 파일이 있는 디렉토리
parser.add_argument("--db_dir", type = str, default = "seoultech_data_db")
parser.add_argument("--collection_name", type = str, default = "seoultech_notices")
parser.add_argument("--embedf_name", type = str, default = "BAAI/bge-m3")
# 크롤러
parser.add_argument("--base_url", type = str, default = "https://www.seoultech.ac.kr/service/info/notice")
parser.add_argument("--download_dir", type = str, default = "seoultech_data_download")
parser.add_argument("--header", type = dict, default = {"User-Agent" : "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"})
parser.add_argument("--num_page", type = int, default = 1)
# 챗봇
parser.add_argument("--api_key", type = str, default = "")
parser.add_argument("--log_dir", type = str, default = "chat_log")
parser.add_argument("--model_name", type = str, default = "llama-3.3-70b-versatile") # llama-3.1-8b-instant llama-3.3-70b-versatile openai/gpt-oss-120b
parser.add_argument("--temperature", type = float, default = 0.5)
parser.add_argument("--n_results", type = int, default = 3)
parser.add_argument("--prompt_type", type = str, default = "v")
args = parser.parse_args()
main(args) |