lexistudio / app.py
scipious's picture
Update app.py
5777af7 verified
raw
history blame
23.4 kB
import os
#os.environ["PYDANTIC_V1_STYLE"] = "1"
#os.environ["PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS"] = "1"
# --------------------------------------------------------------------------
from flask import Flask, render_template, jsonify, request, Response
from flask_socketio import SocketIO, emit
import uuid
import threading
import sqlite3
import gc
import time
import re
import traceback
import requests # API 호출을 위해 필요
# --- Together AI SDK ---
from together import Together
# --- eventlet monkey patch (Gunicorn + SocketIO 필수!) ---
import eventlet
eventlet.monkey_patch()
# --- Flask & SocketIO 설정 ---
app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')
import logging
# 로거 설정: 레벨을 INFO로 설정하고, 포맷을 지정합니다.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# --- 외부 모듈 임포트 ---
import reg_embedding_system
import leximind_prompts
# --- 전역 변수 ---
connected_clients = 0
search_document_number = 30
Filtered_search = False
filters = {"regulation_part": []}
# --- 경로 설정 ---
current_dir = os.path.dirname(os.path.abspath(__file__))
ResultFile_FolderAddress = os.path.join(current_dir, 'result.txt')
# --- RAG 데이터 경로 ---
# NOTE: Hugging Face Spaces에서 데이터가 /app/data에 있는지 확인해야 합니다.
region_paths = {
"국내": "/app/data/KMVSS_RAG",
"북미": "/app/data/FMVSS_RAG",
"유럽": "/app/data/EUR_RAG"
}
# --- 프롬프트 ---
lexi_prompts = leximind_prompts.PromptLibrary()
# 세션별 요청 추적을 위한 딕셔너리
active_sessions = {}
# --- RAG 객체 ---
region_rag_objects = {}
# --- Together AI 설정 (SDK 대신 API 호출에 사용) ---
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
raise EnvironmentError("TOGETHER_API_KEY가 설정되지 않았습니다. Hugging Face Secrets에 추가하세요.")
# client = Together(api_key=TOGETHER_API_KEY) # <--- Together SDK 클라이언트 제거
try:
# TOGETHER_API_KEY를 사용해 클라이언트 초기화 (TOGETHER_API_KEY가 코드 내에 정의되어 있다고 가정)
client = Together(api_key=TOGETHER_API_KEY)
except NameError:
# TOGETHER_API_KEY가 정의되지 않은 경우 환경 변수 사용을 시도
client = Together()
rag_connection_status_info = ""
# --- RAG 로딩 ---
def load_rag_objects():
global region_rag_objects
global rag_connection_status_info
# 로딩 스레드 시작 로그를 추가하여 Gunicorn 로그에서 확인 가능하게 함
logger.info(">>> [RAG_LOADER] RAG 로딩 스레드 시작 <<<")
for region, path in region_paths.items():
if not os.path.exists(path):
msg = f"[{region}] 경로 없음: {path}"
socketio.emit('message', {'message': msg})
logger.info(msg)
continue
try:
socketio.emit('message', {'message': f"[{region}] RAG 로딩 중..."})
rag_connection_status_info = f"[{region}] RAG 로딩 중..."
# NOTE: reg_embedding_system 모듈이 현재 환경에 설치/존재하는지 확인해야 합니다.
ensemble_retriever, vectorstore, sqlite_conn = reg_embedding_system.load_embedding_from_faiss(path)
sqlite_conn.close()
db_path = os.path.join(path, "metadata_mapping.db")
new_conn = sqlite3.connect(db_path, check_same_thread=False)
region_rag_objects[region] = {
"ensemble_retriever": ensemble_retriever,
"vectorstore": vectorstore,
"sqlite_conn": new_conn
}
socketio.emit('message', {'message': f"[{region}] 로딩 완료"})
logger.info(f"[{region}] RAG 로딩 완료")
rag_connection_status_info = f"[{region}] RAG 로딩 완료"
except Exception as e:
error_msg = f"[{region}] 로딩 실패: {str(e)}"
logger.info(error_msg)
# [수정]: 상세한 에러 추적을 위해 traceback 추가
traceback.logger.info_exc()
socketio.emit('message', {'message': error_msg})
socketio.emit('message', {'message': "Ready to Search"})
logger.info("Ready to Search")
rag_connection_status_info = "Ready to Search"
# --- 웹 ---
@app.route('/')
def index():
return render_template('chat_v02.html')
# 전역 변수에 기본값 추가
Search_each_all_mode = True # 기본값을 클라이언트에서 제어 가능
@socketio.on('search_query')
def handle_search_query(data):
global Filtered_search
global filters
global Search_each_all_mode
# 세션 ID 생성
session_id = str(uuid.uuid4())
active_sessions[session_id] = True
# 클라이언트에 session_id 전달
emit('search_started', {'session_id': session_id})
try:
# 클라이언트에서 전송된 검색 모드 사용
Search_each_all_mode = data.get('searchEachMode', True)
query = data.get('query', '')
regions = data.get('regions', [])
selected_regulations = data.get('selectedRegulations', [])
emit('search_status', {'status': 'processing', 'message': '검색 요청을 처리하는 중입니다...'})
logger.info(f"선택된 지역 : {regions}")
logger.info(f"선택된 법규 : {selected_regulations}")
if Search_each_all_mode:
logger.info(f"검색 모드 : 각각 검색")
else:
logger.info(f"검색 모드 : 통합 검색")
# 법규 타입별로 필터 구분
filters = {
"regulation_part": [],
"regulation_section": [],
"chapter_section": [],
"jo": []
}
# 번역 진행 상황 알림
emit('search_status', {'status': 'translating', 'message': '질문을 번역하는 중입니다...'})
if session_id not in active_sessions:
emit('search_cancelled', {'message': '검색이 취소되었습니다.'})
emit('search_status', {'status': 'processing', 'message': 'Ready to search'})
return
Translated_query = Gemma3_AI_Translate(query)
emit('search_status', {'status': 'translated', 'message': f'번역 완료: {Translated_query}'})
logger.info(f"Query: Original query : {query}")
logger.info(f"Query: Translated_query : {Translated_query}")
if selected_regulations:
Filtered_search = True
cont_selected_num = 0
# 파일로 저장
output_path = os.path.join(current_dir, "merged_ai_messages.txt")
if os.path.exists(output_path):
os.remove(output_path)
print(f"기존 파일 삭제 완료: {output_path}")
if Search_each_all_mode:
# 각각 검색 모드
emit('search_status', {'status': 'searching', 'message': f'선택된 {len(selected_regulations)}개 법규를 각각 검색 중...'})
for i, regulation in enumerate(selected_regulations):
if session_id not in active_sessions:
emit('search_cancelled', {'message': '검색이 취소되었습니다.'})
emit('search_status', {'status': 'processing', 'message': 'Ready to search'})
return
regulation_title = regulation.get('title', '')
regulation_id = regulation.get('id', '')
regulation_type = regulation.get('type', 'part') # 타입 정보 추출
emit('search_status', {
'status': 'searching_regulation',
'message': f'법규 {i+1}/{len(selected_regulations)}: [{regulation_type.upper()}] {regulation_title} 검색 중...',
'progress': (i / len(selected_regulations)) * 100
})
# 법규 타입별 필터 생성
current_filters = create_filter_by_type(regulation_type, regulation_title)
print(f"[{regulation_type}] 필터에 추가: {regulation_title}")
print(f"생성된 필터: {current_filters}")
Rag_Results = search_DB_from_multiple_regions(Translated_query, regions, region_rag_objects, current_filters)
if session_id not in active_sessions:
emit('search_cancelled', {'message': '검색이 취소되었습니다.'})
emit('search_status', {'status': 'processing', 'message': 'Ready to search'})
return
emit('search_status', {
'status': 'ai_processing',
'message': f'AI가 [{regulation_type.upper()}] {regulation_title}에 대한 답변을 생성 중...'
})
AImessage = RegAI(query, Rag_Results, ResultFile_FolderAddress)
logger.info(f"Answer: {AImessage}")
if session_id not in active_sessions:
emit('search_cancelled', {'message': '검색이 취소되었습니다.'})
return
# 각 법규별 결과를 실시간으로 전송 (타입 정보 포함)
emit('regulation_result', {
'regulation_title': f"[{regulation_type.upper()}] {regulation_title}",
'regulation_index': i + 1,
'total_regulations': len(selected_regulations),
'regulation_type': regulation_type,
'result': AImessage
})
# 파일에 저장
if isinstance(AImessage, str) and AImessage.strip():
with open(output_path, "a", encoding="utf-8") as f:
cont_selected_num += 1
from datetime import datetime
stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
f.write(f"\n--- [{stamp}] message #{cont_selected_num} --- Regulation Type: {regulation_type} --- Regulation Name : {regulation_title} ---\n {AImessage}")
emit('search_complete', {'status': 'completed', 'message': '모든 법규 검색이 완료되었습니다.'})
else:
# 통합 검색 모드 - 타입별로 그룹화
grouped_regulations = group_regulations_by_type(selected_regulations)
emit('search_status', {'status': 'searching', 'message': f'선택된 {len(selected_regulations)}개 법규를 타입별로 통합하여 검색 중...'})
# 타입별로 필터 생성
combined_filters = create_combined_filters(grouped_regulations)
logger.info(f"통합 필터: {combined_filters}")
Rag_Results = search_DB_from_multiple_regions(Translated_query, regions, region_rag_objects, combined_filters)
if session_id in active_sessions:
emit('search_status', {'status': 'ai_processing', 'message': 'AI가 통합 답변을 생성 중...'})
AImessage = RegAI(query, Rag_Results, ResultFile_FolderAddress)
logger.info(f"Answer: {AImessage}")
if session_id in active_sessions:
emit('search_result', {'result': AImessage})
emit('search_complete', {'status': 'completed', 'message': '통합 검색이 완료되었습니다.'})
else:
Filtered_search = False
emit('search_status', {'status': 'searching_all', 'message': '전체 법규에서 검색 중...'})
# 필터 없이 검색
Rag_Results = search_DB_from_multiple_regions(Translated_query, regions, region_rag_objects, None)
if session_id in active_sessions:
emit('search_status', {'status': 'ai_processing', 'message': 'AI가 답변을 생성 중...'})
AImessage = RegAI(query, Rag_Results, ResultFile_FolderAddress)
logger.info(f"Answer: {AImessage}")
if session_id in active_sessions:
emit('search_result', {'result': AImessage})
emit('search_complete', {'status': 'completed', 'message': '검색이 완료되었습니다.'})
except Exception as e:
print(f"검색 오류: {e}")
emit('search_error', {'error': str(e), 'message': '검색 중 오류가 발생했습니다.'})
finally:
# 세션 정리
if session_id in active_sessions:
del active_sessions[session_id]
@socketio.on('cancel_search')
def handle_cancel_search(data):
session_id = data.get('session_id')
if session_id and session_id in active_sessions:
del active_sessions[session_id]
emit('search_cancelled', {'message': '검색이 취소되었습니다.'})
# --- 법규 리스트 ---
@app.route('/get_reg_list', methods=['POST'])
def get_reg_list():
data = request.get_json()
selected_regions = data.get('regions', []) or ["국내", "북미", "유럽"]
all_reg_list_part = []
for region in selected_regions:
rag = region_rag_objects.get(region)
if not rag:
continue
try:
conn = rag["sqlite_conn"]
parts = reg_embedding_system.get_unique_metadata_values(conn, "regulation_part")
all_reg_list_part.extend(parts)
except Exception as e:
logger.info(f"[{region}] 법규 로드 실패: {e}")
unique_parts = sorted(set(all_reg_list_part), key=reg_embedding_system.natural_sort_key)
return jsonify(reg_list_part="\n".join(unique_parts))
# --- SocketIO ---
@socketio.on('connect')
def handle_connect():
global connected_clients
connected_clients += 1
# 클라이언트 IP 가져오기
client_ip = request.remote_addr
# 프록시(Nginx, Cloudflare 등)를 거치는 경우 실제 IP는 헤더에 들어있을 수 있음
if request.headers.get('X-Forwarded-For'):
# X-Forwarded-For 는 "client, proxy1, proxy2" 형태로 여러 IP가 있을 수 있음
client_ip = request.headers.get('X-Forwarded-For').split(',')[0].strip()
elif request.headers.get('X-Real-IP'):
client_ip = request.headers.get('X-Real-IP')
# Cloudflare의 경우
elif request.headers.get('CF-Connecting-IP'):
client_ip = request.headers.get('CF-Connecting-IP')
logger.info(f"클라이언트 연결 | IP: {client_ip} | 현재 접속자: {connected_clients}명")
global rag_connection_status_info
socketio.emit('message', {'message': rag_connection_status_info})
@socketio.on('disconnect')
def handle_disconnect():
global connected_clients
connected_clients -= 1
logger.info(f"클라이언트 연결: {connected_clients}명")
#if connected_clients <= 0:
# cleanup_connections()
# logger.info("서버 종료")
# os._exit(0)
def cleanup_connections():
for region, rag in region_rag_objects.items():
try:
rag["sqlite_conn"].close()
logger.info(f"[{region}] DB 연결 종료")
except:
pass
# --- Together AI 분석 (SDK -> requests 직접 호출로 변경) ---
def Gemma3_AI_analysis(query_txt, content_txt):
content_txt = "\n".join(doc.page_content for doc in content_txt) if isinstance(content_txt, list) else str(content_txt)
query_txt = str(query_txt)
prompt = lexi_prompts.use_prompt(lexi_prompts.AI_system_prompt, query_txt=query_txt, content_txt=content_txt)
try:
response = client.chat.completions.create(
#model="meta-llama/Llama-4-Scout-17B-16E-Instruct", #비용 효율 측면 최고
model="moonshotai/Kimi-K2-Instruct-0905", #오픈소스 최고 성능
messages=[
{
"role": "user",
"content": prompt,
}
],
)
# 응답에서 결과 텍스트를 추출
AI_Result = response.choices[0].message.content
return AI_Result
except Exception as e:
# Together SDK의 오류는 requests.exceptions.RequestException이 아닌 다른 종류의 예외로 발생합니다.
# 따라서 일반적인 Exception으로 처리하는 것이 안전합니다.
logger.info(f"Together AI 분석 API 호출 실패: {e}")
traceback.print_exc() # traceback.logger.info_exc() 대신 일반 print_exc()를 사용하거나, logging 모듈 설정을 확인하세요.
return f"AI 분석 중 오류가 발생했습니다: {e}"
# --- Together AI 번역 (SDK -> requests 직접 호출로 변경) ---
def Gemma3_AI_Translate(query_txt):
query_txt = str(query_txt)
prompt = lexi_prompts.use_prompt(lexi_prompts.query_translator, query_txt=query_txt)
try:
response = client.chat.completions.create(
#model="meta-llama/Llama-4-Scout-17B-16E-Instruct", #비용 효율 측면 최고
model="moonshotai/Kimi-K2-Instruct-0905", #오픈소스 최고 성능
messages=[
{
"role": "user",
"content": prompt,
}
],
)
# 응답에서 결과 텍스트를 추출
AI_Result = response.choices[0].message.content
return AI_Result
except Exception as e:
# API 호출 실패 시 처리 (SDK 사용 시 일반 Exception으로 처리)
logger.info(f"Together AI 번역 API 호출 실패: {e}")
# traceback.logger.info_exc() 대신 traceback.print_exc() 사용 (권장)
# 만약 기존 로깅 시스템에서 해당 함수를 정의해 사용하고 있다면 그대로 두셔도 됩니다.
# 여기서는 표준 traceback 모듈을 사용합니다.
traceback.print_exc()
# 번역 실패 시 query_txt 변수를 반환 (기존 코드 로직 반영)
return query_txt
# --- 검색 ---
# 검색 함수 수정
def search_DB_from_multiple_regions(query, selected_regions, region_rag_objects, custom_filters=None):
global Filtered_search
global filters
if not selected_regions:
selected_regions = list(region_rag_objects.keys())
print(f"Translated Query : {query}")
# 커스텀 필터가 제공된 경우 사용
search_filters = custom_filters if custom_filters is not None else filters
# 필터가 설정되어 있는지 확인
has_filters = any(search_filters.get(key, []) for key in search_filters.keys())
print(f"사용된 검색 필터: {search_filters}")
print(f"필터 사용 여부: {has_filters}")
combined_results = []
for region in selected_regions:
rag = region_rag_objects.get(region)
if not rag:
continue
ensemble_retriever = rag["ensemble_retriever"]
vectorstore = rag["vectorstore"]
sqlite_conn = rag["sqlite_conn"]
if ensemble_retriever:
if has_filters:
results = reg_embedding_system.search_with_metadata_filter(
ensemble_retriever=ensemble_retriever,
vectorstore=vectorstore,
query=query,
k=search_document_number,
metadata_filter=search_filters,
sqlite_conn=sqlite_conn
)
else:
results = reg_embedding_system.smart_search_vectorstore(
retriever=ensemble_retriever,
query=query,
k=search_document_number,
vectorstore=vectorstore,
sqlite_conn=sqlite_conn,
enable_detailed_search=True
)
print(f"[{region}] 검색 완료: {len(results)}건")
combined_results.extend(results)
return combined_results
# --- 최종 AI ---
def RegAI(query, Rag_Results, ResultFile_FolderAddress):
gc.collect()
AI_Result = "검색 결과가 없습니다." if not Rag_Results else Gemma3_AI_analysis(query, Rag_Results)
#with open(ResultFile_FolderAddress, 'w', encoding='utf-8') as f:
# print("검색된 문서:", file=f)
# logger.info("검색된 문서:")
# for i, doc in enumerate(Rag_Results):
# print(f"문서 {i+1}: {doc.page_content[:200]}... (메타: {doc.metadata})", file=f)
# logger.info(f"문서 {i+1}: {doc.page_content[:200]}... (메타: {doc.metadata})")
# print("\n답변:", file=f)
# logger.info("\n답변:")
# print(AI_Result, file=f)
# logger.info(AI_Result)
return AI_Result
# 법규 타입별 필터 생성 함수
def create_filter_by_type(regulation_type, regulation_title):
"""법규 타입에 따라 적절한 필터 딕셔너리 생성"""
filter_dict = {
"regulation_part": [],
"regulation_section": [],
"chapter_section": [],
"jo": []
}
# 타입별 매핑
type_mapping = {
"part": "regulation_part",
"section": "regulation_section",
"chapter": "chapter_section",
"jo": "jo"
}
filter_key = type_mapping.get(regulation_type, "regulation_part")
filter_dict[filter_key].append(regulation_title)
return filter_dict
# 법규들을 타입별로 그룹화하는 함수
def group_regulations_by_type(selected_regulations):
"""선택된 법규들을 타입별로 그룹화"""
grouped = {
"part": [],
"section": [],
"chapter": [],
"jo": []
}
for regulation in selected_regulations:
regulation_type = regulation.get('type', 'part')
regulation_title = regulation.get('title', '')
if regulation_title and regulation_type in grouped:
grouped[regulation_type].append(regulation_title)
return grouped
# 통합 필터 생성 함수
def create_combined_filters(grouped_regulations):
"""그룹화된 법규들로부터 통합 필터 생성"""
filters = {
"regulation_part": grouped_regulations["part"],
"regulation_section": grouped_regulations["section"],
"chapter_section": grouped_regulations["chapter"],
"jo": grouped_regulations["jo"]
}
return filters
# --- 실행 ---
if __name__ == '__main__':
# 로컬 개발용
threading.Thread(target=load_rag_objects, daemon=True).start()
time.sleep(2)
socketio.emit('message', {'message': '데이터 로딩 시작...'})
socketio.run(app, host='0.0.0.0', port=7860, debug=False)
else:
# Gunicorn용: 워커 시작 후 로딩
import atexit
loading_thread = threading.Thread(target=load_rag_objects, daemon=True)
loading_thread.start()
atexit.register(cleanup_connections)