# lexistudio / app.py
# NOTE(review): the following Hugging Face Spaces page chrome was scraped into
# the file ("scipious's picture / Update app.py / afa43df verified / raw /
# history blame / 10.9 kB") — it is not Python and has been commented out so
# the module parses.
import os

#os.environ["PYDANTIC_V1_STYLE"] = "1"
#os.environ["PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS"] = "1"
# --------------------------------------------------------------------------
from flask import Flask, render_template, jsonify, request
from flask_socketio import SocketIO
import threading
import sqlite3
import gc
import time
import re
import traceback
import requests  # used for direct HTTP calls to the Together AI API

# --- Together AI SDK (removed) ---
# from together import Together

# --- eventlet monkey patch (required when running SocketIO under Gunicorn!) ---
import eventlet
eventlet.monkey_patch()

# --- Flask & SocketIO setup ---
app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*", async_mode='eventlet')

import logging

# Logger setup: INFO level with a timestamped "name - level - message" format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# --- Project-local modules ---
import reg_embedding_system
import leximind_prompts

# --- Global state ---
connected_clients = 0               # number of currently connected SocketIO clients
search_document_number = 30         # documents to retrieve per region per query
Filtered_search = False             # rebound per request in get_message()
filters = {"regulation_part": []}   # metadata filter, rebuilt per request in get_message()

# --- Path setup ---
current_dir = os.path.dirname(os.path.abspath(__file__))
ResultFile_FolderAddress = os.path.join(current_dir, 'result.txt')  # audit-trail output file

# --- RAG data paths ---
# NOTE: on Hugging Face Spaces, confirm the data actually lives under /app/data.
region_paths = {
    "국내": "/app/data/KMVSS_RAG",
    "북미": "/app/data/FMVSS_RAG",
    "유럽": "/app/data/EUR_RAG"
}

# --- Prompt library ---
lexi_prompts = leximind_prompts.PromptLibrary()

# --- Per-region RAG objects, populated asynchronously by load_rag_objects() ---
region_rag_objects = {}

# --- Together AI configuration (raw API calls are used instead of the SDK) ---
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    raise EnvironmentError("TOGETHER_API_KEY가 설정되지 않았습니다. Hugging Face Secrets에 추가하세요.")
# client = Together(api_key=TOGETHER_API_KEY) # <--- Together SDK client removed
# --- RAG 로딩 ---
def load_rag_objects():
    """Load FAISS-based RAG objects for every configured region.

    Populates the module-level ``region_rag_objects`` dict with one entry per
    region: its ensemble retriever, vectorstore, and a SQLite connection to
    the metadata DB. Runs on a background thread at startup; progress and
    errors are pushed to SocketIO clients and mirrored to the server log.
    """
    global region_rag_objects
    # Startup log line so the loader thread is visible in Gunicorn logs.
    logger.info(">>> [RAG_LOADER] RAG 로딩 스레드 시작 <<<")
    for region, path in region_paths.items():
        if not os.path.exists(path):
            msg = f"[{region}] 경로 없음: {path}"
            socketio.emit('message', {'message': msg})
            logger.info(msg)
            continue
        try:
            socketio.emit('message', {'message': f"[{region}] RAG 로딩 중..."})
            # NOTE(review): reg_embedding_system must be importable in this environment.
            ensemble_retriever, vectorstore, sqlite_conn = reg_embedding_system.load_embedding_from_faiss(path)
            # Re-open the metadata DB ourselves with check_same_thread=False so
            # request handlers running on other threads/greenlets can reuse it.
            sqlite_conn.close()
            db_path = os.path.join(path, "metadata_mapping.db")
            new_conn = sqlite3.connect(db_path, check_same_thread=False)
            region_rag_objects[region] = {
                "ensemble_retriever": ensemble_retriever,
                "vectorstore": vectorstore,
                "sqlite_conn": new_conn
            }
            socketio.emit('message', {'message': f"[{region}] 로딩 완료"})
            logger.info(f"[{region}] RAG 로딩 완료")
        except Exception as e:
            error_msg = f"[{region}] 로딩 실패: {str(e)}"
            # BUG FIX: the original called traceback.logger.info_exc(), which
            # raised AttributeError inside the handler. logger.exception logs
            # the message together with the full traceback.
            logger.exception(error_msg)
            socketio.emit('message', {'message': error_msg})
    socketio.emit('message', {'message': "Ready to Search"})
    logger.info("Ready to Search")
# --- 웹 ---
@app.route('/')
def index():
    """Serve the single-page chat UI."""
    return render_template('chat.html')
# --- 메시지 ---
@app.route('/get_message', methods=['POST'])
def get_message():
    """Handle one chat query: rebuild the metadata filter, run the regional
    RAG search, and return the AI analysis as JSON ``{"message": ...}``."""
    global Filtered_search, filters
    payload = request.get_json()
    query = payload.get('query', '').strip()
    regions = payload.get('regions', [])
    selected_regulations = payload.get('selectedRegulations', [])
    # Rebuild the module-level filter from the regulations the user ticked,
    # keeping only non-empty titles.
    titles = [reg.get('title', '') for reg in selected_regulations]
    filters = {"regulation_part": [t for t in titles if t]}
    Filtered_search = bool(selected_regulations)
    rag_results = search_DB_from_multiple_regions(query, regions, region_rag_objects)
    answer = RegAI(query, rag_results, ResultFile_FolderAddress)
    return jsonify(message=answer)
# --- 법규 리스트 ---
@app.route('/get_reg_list', methods=['POST'])
def get_reg_list():
    """Return the unique, naturally-sorted regulation-part names for the
    requested regions as one newline-joined string."""
    payload = request.get_json()
    regions = payload.get('regions', []) or ["국내", "북미", "유럽"]
    collected = []
    for region in regions:
        rag = region_rag_objects.get(region)
        if rag is None:
            continue  # this region's RAG store has not been loaded
        try:
            parts = reg_embedding_system.get_unique_metadata_values(
                rag["sqlite_conn"], "regulation_part"
            )
            collected.extend(parts)
        except Exception as e:
            logger.info(f"[{region}] 법규 로드 실패: {e}")
    unique_parts = sorted(set(collected), key=reg_embedding_system.natural_sort_key)
    return jsonify(reg_list_part="\n".join(unique_parts))
# --- SocketIO ---
@socketio.on('connect')
def handle_connect():
    """Bump the live-client counter when a SocketIO client connects."""
    global connected_clients
    connected_clients = connected_clients + 1
    logger.info(f"클라이언트 연결: {connected_clients}명")
@socketio.on('disconnect')
def handle_disconnect():
    """Decrement the live-client counter; when the last client leaves, close
    the DB connections and terminate the process.

    NOTE(review): os._exit(0) kills the worker immediately (no atexit hooks,
    no finally blocks) — presumably intentional so the Space restarts and
    frees memory; confirm before changing.
    """
    global connected_clients
    connected_clients -= 1
    logger.info(f"연결 해제: {connected_clients}명")
    if connected_clients <= 0:
        cleanup_connections()
        logger.info("서버 종료")
        os._exit(0)
def cleanup_connections():
    """Close every per-region SQLite connection in region_rag_objects.

    Called when the last client disconnects and at interpreter exit; cleanup
    is best-effort, so failures are logged and skipped.
    """
    for region, rag in region_rag_objects.items():
        try:
            rag["sqlite_conn"].close()
            logger.info(f"[{region}] DB 연결 종료")
        except Exception as e:
            # BUG FIX: narrowed from a bare `except:` (which also swallowed
            # SystemExit/KeyboardInterrupt) and made the failure visible.
            logger.warning(f"[{region}] DB 연결 종료 실패: {e}")
# --- Together AI 분석 (SDK -> requests 직접 호출로 변경) ---
def Gemma3_AI_analysis(query_txt, content_txt):
    """Ask the Together AI chat-completions API to analyze retrieved documents.

    Parameters
    ----------
    query_txt : str
        The user's (translated) question.
    content_txt : list | str
        Retrieved document objects (with ``.page_content``) or pre-joined text.

    Returns
    -------
    str
        The model's answer, or a Korean error message if the API call fails.
    """
    # Flatten a list of document objects into one newline-joined context string.
    content_txt = "\n".join(doc.page_content for doc in content_txt) if isinstance(content_txt, list) else str(content_txt)
    query_txt = str(query_txt)
    prompt = lexi_prompts.use_prompt(lexi_prompts.AI_system_prompt, query_txt=query_txt, content_txt=content_txt)
    headers = {
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 1024,
        "temperature": 0.7
    }
    try:
        response = requests.post("https://api.together.xyz/v1/chat/completions", headers=headers, json=payload, timeout=120)
        response.raise_for_status()  # raise on HTTP error status
        data = response.json()
        return data["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        # BUG FIX: the original called traceback.logger.info_exc(), which
        # raised AttributeError inside the handler, masking the real error.
        # logger.exception logs the message together with the full traceback.
        logger.exception(f"Together AI 분석 API 호출 실패: {e}")
        return f"AI 분석 중 오류가 발생했습니다: {e}"
# --- Together AI 번역 (SDK -> requests 직접 호출로 변경) ---
def Gemma3_AI_Translate(query_txt):
    """Translate/normalize the user query via the Together AI API.

    Parameters
    ----------
    query_txt : str
        The raw user query.

    Returns
    -------
    str
        The translated query, or the original query if the API call fails
        (graceful degradation so search still works).
    """
    query_txt = str(query_txt)
    prompt = lexi_prompts.use_prompt(lexi_prompts.query_translator, query_txt=query_txt)
    headers = {
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "meta-llama/Llama-3.2-3B-Instruct-Turbo",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 512,
        "temperature": 0.3
    }
    try:
        response = requests.post("https://api.together.xyz/v1/chat/completions", headers=headers, json=payload, timeout=60)
        response.raise_for_status()  # raise on HTTP error status
        data = response.json()
        return data["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        # BUG FIX: the original called traceback.logger.info_exc(), which
        # raised AttributeError and turned a recoverable API failure into a
        # crash. logger.exception logs the message with the full traceback.
        logger.exception(f"Together AI 번역 API 호출 실패: {e}")
        return query_txt  # fall back to the untranslated query
# --- 검색 ---
def search_DB_from_multiple_regions(query, selected_regions, region_rag_objects):
    """Translate the query, search each selected region's RAG store, and
    return the concatenated list of retrieved documents.

    Falls back to all loaded regions when ``selected_regions`` is empty, and
    honours the module-level ``Filtered_search``/``filters`` state set by
    get_message().
    """
    regions = selected_regions or list(region_rag_objects.keys())
    query = Gemma3_AI_Translate(query)
    logger.info(f"번역된 쿼리: {query}")
    combined_results = []
    for region in regions:
        rag = region_rag_objects.get(region)
        if rag is None:
            continue  # region not loaded yet
        if Filtered_search:
            results = reg_embedding_system.search_with_metadata_filter(
                ensemble_retriever=rag["ensemble_retriever"],
                vectorstore=rag["vectorstore"],
                query=query,
                k=search_document_number,
                metadata_filter=filters,
                sqlite_conn=rag["sqlite_conn"],
            )
        else:
            results = reg_embedding_system.smart_search_vectorstore(
                retriever=rag["ensemble_retriever"],
                query=query,
                k=search_document_number,
                vectorstore=rag["vectorstore"],
                sqlite_conn=rag["sqlite_conn"],
                enable_detailed_search=True,
            )
        logger.info(f"[{region}] 검색: {len(results)}건")
        combined_results.extend(results)
    return combined_results
# --- 최종 AI ---
def RegAI(query, Rag_Results, ResultFile_FolderAddress):
    """Produce the final AI answer and persist an audit trail to disk.

    Parameters
    ----------
    query : str
        The user query.
    Rag_Results : list
        Retrieved documents (objects with ``.page_content`` and ``.metadata``);
        may be empty.
    ResultFile_FolderAddress : str
        Path of the result .txt file to (over)write — a file path despite the
        name.

    Returns
    -------
    str
        The AI analysis, or a Korean "no results" message when nothing was
        retrieved.
    """
    gc.collect()
    AI_Result = "검색 결과가 없습니다." if not Rag_Results else Gemma3_AI_analysis(query, Rag_Results)
    with open(ResultFile_FolderAddress, 'w', encoding='utf-8') as f:
        # BUG FIX: logging methods accept no `file=` keyword, so the original
        # `logger.info(..., file=f)` calls raised TypeError on every request.
        # print(..., file=f) writes the audit trail as intended.
        print("검색된 문서:", file=f)
        for i, doc in enumerate(Rag_Results):
            print(f"문서 {i+1}: {doc.page_content[:200]}... (메타: {doc.metadata})", file=f)
        print("\n답변:", file=f)
        print(AI_Result, file=f)
    return AI_Result
# --- Entry point ---
if __name__ == '__main__':
    # Local development: load RAG data on a background thread, then serve.
    threading.Thread(target=load_rag_objects, daemon=True).start()
    time.sleep(2)
    socketio.emit('message', {'message': '데이터 로딩 시작...'})
    socketio.run(app, host='0.0.0.0', port=7860, debug=False)
else:
    # Under Gunicorn: start RAG loading after the worker boots, and register
    # DB-connection cleanup for interpreter exit.
    import atexit
    loading_thread = threading.Thread(target=load_rag_objects, daemon=True)
    loading_thread.start()
    atexit.register(cleanup_connections)