jeongsoo's picture
Upload fix
e234e39
"""
RAG 검색 챗봇 μ›Ή μ• ν”Œλ¦¬μΌ€μ΄μ…˜
"""
import os
import json
import logging
import tempfile
import threading
from flask import Flask, request, jsonify, render_template, send_from_directory
from werkzeug.utils import secure_filename
from dotenv import load_dotenv
# 둜거 μ„€μ •
logging.basicConfig(
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.INFO
)
logger = logging.getLogger(__name__)
# ν™˜κ²½ λ³€μˆ˜ λ‘œλ“œ
load_dotenv()
# 둜컬 λͺ¨λ“ˆ μž„ν¬νŠΈ
from utils.vito_stt import VitoSTT
from utils.llm_client import DeepSeekLLM
from utils.document_processor import DocumentProcessor
from retrieval.vector_retriever import VectorRetriever
from retrieval.reranker import ReRanker
# Flask μ•± μ΄ˆκΈ°ν™”
app = Flask(__name__)
# μ΅œλŒ€ 파일 크기 μ„€μ • (10MB)
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024
app.config['UPLOAD_FOLDER'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
app.config['DATA_FOLDER'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
app.config['INDEX_PATH'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data', 'index')
# μ—…λ‘œλ“œ 폴더가 μ—†μœΌλ©΄ 생성
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['DATA_FOLDER'], exist_ok=True)
os.makedirs(app.config['INDEX_PATH'], exist_ok=True)
# ν—ˆμš©λ˜λŠ” μ˜€λ””μ˜€ 파일 ν™•μž₯자
ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'ogg', 'm4a'}
# ν—ˆμš©λ˜λŠ” λ¬Έμ„œ 파일 ν™•μž₯자
ALLOWED_DOC_EXTENSIONS = {'txt', 'md', 'pdf', 'docx', 'csv'}
# DeepSeek LLM ν΄λΌμ΄μ–ΈνŠΈ μ΄ˆκΈ°ν™”
llm_client = DeepSeekLLM()
# VITO STT ν΄λΌμ΄μ–ΈνŠΈ μ΄ˆκΈ°ν™”
stt_client = VitoSTT()
# μ „μ—­ 검색기 객체와 μž¬μˆœμœ„ν™” 검색기 객체
base_retriever = None
retriever = None
# μ•± μ΄ˆκΈ°ν™” μƒνƒœ
app_ready = False
def allowed_audio_file(filename):
"""파일이 ν—ˆμš©λœ μ˜€λ””μ˜€ ν™•μž₯자λ₯Ό κ°€μ§€λŠ”μ§€ 확인"""
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_AUDIO_EXTENSIONS
def allowed_doc_file(filename):
"""파일이 ν—ˆμš©λœ λ¬Έμ„œ ν™•μž₯자λ₯Ό κ°€μ§€λŠ”μ§€ 확인"""
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_DOC_EXTENSIONS
def init_retriever():
"""검색기 객체 μ΄ˆκΈ°ν™” λ˜λŠ” λ‘œλ“œ"""
global base_retriever, retriever
index_path = app.config['INDEX_PATH']
# κΈ°μ‘΄ μΈλ±μŠ€κ°€ μžˆλŠ”μ§€ 확인
if os.path.exists(os.path.join(index_path, "documents.json")):
try:
logger.info(f"κΈ°μ‘΄ 벑터 인덱슀λ₯Ό '{index_path}'μ—μ„œ λ‘œλ“œν•©λ‹ˆλ‹€...")
base_retriever = VectorRetriever.load(index_path)
logger.info(f"{len(base_retriever.documents)}개 λ¬Έμ„œκ°€ λ‘œλ“œλ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
except Exception as e:
logger.error(f"인덱슀 λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
logger.info("μƒˆ 검색기λ₯Ό μ΄ˆκΈ°ν™”ν•©λ‹ˆλ‹€...")
base_retriever = VectorRetriever()
else:
logger.info("κΈ°μ‘΄ 인덱슀λ₯Ό 찾을 수 μ—†μ–΄ μƒˆ 검색기λ₯Ό μ΄ˆκΈ°ν™”ν•©λ‹ˆλ‹€...")
base_retriever = VectorRetriever()
# 데이터 ν΄λ”μ˜ λ¬Έμ„œ λ‘œλ“œ
data_path = app.config['DATA_FOLDER']
if not base_retriever.documents and os.path.exists(data_path):
logger.info(f"{data_path}μ—μ„œ λ¬Έμ„œλ₯Ό λ‘œλ“œν•©λ‹ˆλ‹€...")
docs = DocumentProcessor.load_documents_from_directory(
data_path,
extensions=[".txt", ".md", ".csv"],
recursive=True
)
if docs:
logger.info(f"{len(docs)}개 λ¬Έμ„œλ₯Ό 검색기에 μΆ”κ°€ν•©λ‹ˆλ‹€...")
base_retriever.add_documents(docs)
# 인덱슀 μ €μž₯
logger.info(f"검색기 μƒνƒœλ₯Ό '{index_path}'에 μ €μž₯ν•©λ‹ˆλ‹€...")
try:
base_retriever.save(index_path)
logger.info("인덱슀 μ €μž₯ μ™„λ£Œ")
except Exception as e:
logger.error(f"인덱슀 μ €μž₯ 쀑 였λ₯˜ λ°œμƒ: {e}")
# μž¬μˆœμœ„ν™” 검색기 μ΄ˆκΈ°ν™”
logger.info("μž¬μˆœμœ„ν™” 검색기λ₯Ό μ΄ˆκΈ°ν™”ν•©λ‹ˆλ‹€...")
# 자체 κ΅¬ν˜„λœ μž¬μˆœμœ„ν™” ν•¨μˆ˜ - κ°„λ‹¨ν•œ TF-IDF 기반 점수 μž¬κ³„μ‚°
def custom_rerank_fn(query, results):
"""κ°„λ‹¨ν•œ 자체 κ΅¬ν˜„ μž¬μˆœμœ„ν™” ν•¨μˆ˜"""
# 쿼리 단어 뢄석
query_terms = set(query.lower().split())
# κ²°κ³Ό μž¬μ μˆ˜ν™”
for result in results:
if "text" in result:
text = result["text"].lower()
# κ°„λ‹¨ν•œ TF 기반 점수 계산
term_freq = sum(1 for term in query_terms if term in text)
# 길이 μ •κ·œν™”
normalized_score = term_freq / (len(text.split()) + 1) * 10
# κΈ°μ‘΄ μž„λ² λ”© μ μˆ˜μ™€ μƒˆ 점수 κ²°ν•©
result["rerank_score"] = result.get("score", 0) * 0.7 + normalized_score * 0.3
else:
# ν…μŠ€νŠΈκ°€ μ—†λŠ” 경우 μ›λž˜ 점수 μœ μ§€
result["rerank_score"] = result.get("score", 0)
# μž¬μ μˆ˜ν™”λœ κ²°κ³Ό μ •λ ¬
results.sort(key=lambda x: x.get("rerank_score", 0), reverse=True)
return results
# μž¬μˆœμœ„ν™” 검색기 객체 생성 (CrossEncoder λŒ€μ‹  μ‚¬μš©μž μ •μ˜ ν•¨μˆ˜ μ‚¬μš©)
retriever = ReRanker(
base_retriever=base_retriever,
rerank_fn=custom_rerank_fn,
rerank_field="text"
)
logger.info("μž¬μˆœμœ„ν™” 검색기 μ΄ˆκΈ°ν™” μ™„λ£Œ")
return retriever
# 비동기 μ΄ˆκΈ°ν™” ν•¨μˆ˜
def background_init():
"""λ°±κ·ΈλΌμš΄λ“œμ—μ„œ 검색기 μ΄ˆκΈ°ν™” μˆ˜ν–‰"""
global app_ready, retriever
try:
logger.info("λ°±κ·ΈλΌμš΄λ“œ μ΄ˆκΈ°ν™” μ‹œμž‘")
retriever = init_retriever()
app_ready = True
logger.info("μ•± μ΄ˆκΈ°ν™” μ™„λ£Œ")
except Exception as e:
logger.error(f"μ•± μ΄ˆκΈ°ν™” 쀑 였λ₯˜ λ°œμƒ: {e}", exc_info=True)
app_ready = False
# λ°±κ·ΈλΌμš΄λ“œ μŠ€λ ˆλ“œμ—μ„œ μ΄ˆκΈ°ν™” μ‹œμž‘
init_thread = threading.Thread(target=background_init)
init_thread.daemon = True
init_thread.start()
@app.route('/')
def index():
"""메인 νŽ˜μ΄μ§€"""
if not app_ready:
return render_template('loading.html')
return render_template('index.html')
@app.route('/api/status')
def app_status():
"""μ•± μ΄ˆκΈ°ν™” μƒνƒœ 확인 API"""
return jsonify({"ready": app_ready})
@app.route('/api/chat', methods=['POST'])
def chat():
"""ν…μŠ€νŠΈ 기반 챗봇 API"""
global retriever, app_ready
# μ•± μ€€λΉ„ μƒνƒœ 확인
if not app_ready:
return jsonify({"error": "앱이 아직 μ΄ˆκΈ°ν™” μ€‘μž…λ‹ˆλ‹€. μž μ‹œ ν›„ λ‹€μ‹œ μ‹œλ„ν•΄μ£Όμ„Έμš”."}), 503
try:
data = request.get_json()
if not data or 'query' not in data:
return jsonify({"error": "쿼리가 μ œκ³΅λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."}), 400
query = data['query']
logger.info(f"쿼리 μˆ˜μ‹ : {query}")
# RAG 검색 μˆ˜ν–‰ (μž¬μˆœμœ„ν™” 적용)
search_results = retriever.search(query, top_k=5, first_stage_k=20)
# 검색 κ²°κ³Όμ—μ„œ μ»¨ν…μŠ€νŠΈ μΆ”μΆœ
context = DocumentProcessor.prepare_rag_context(search_results, field="text")
if not context:
logger.warning("검색 κ²°κ³Όκ°€ μ—†μŠ΅λ‹ˆλ‹€.")
return jsonify({
"answer": "μ£„μ†‘ν•©λ‹ˆλ‹€. κ΄€λ ¨ 정보λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.",
"sources": []
})
# LLM에 질의
answer = llm_client.rag_generate(query, context)
# μ†ŒμŠ€ 정보 μΆ”μΆœ
sources = []
for result in search_results:
if "source" in result:
source_info = {
"source": result.get("source", "Unknown"),
"score": result.get("rerank_score", result.get("score", 0))
}
sources.append(source_info)
return jsonify({
"answer": answer,
"sources": sources
})
except Exception as e:
logger.error(f"μ±„νŒ… 처리 쀑 였λ₯˜ λ°œμƒ: {e}", exc_info=True)
return jsonify({"error": f"처리 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}"}), 500
@app.route('/api/voice', methods=['POST'])
def voice_chat():
"""
μŒμ„± μ±— API μ—”λ“œν¬μΈνŠΈ: μ˜€λ””μ˜€ νŒŒμΌμ„ λ°›μ•„ ν…μŠ€νŠΈλ‘œ λ³€ν™˜ν•˜κ³ , μ§ˆλ¬Έμ— λŒ€ν•œ 응닡과 μ†ŒμŠ€λ₯Ό λ°˜ν™˜
Returns:
JSON 응닡:
- transcription: μΈμ‹λœ ν…μŠ€νŠΈ
- answer: LLMμ—μ„œ μƒμ„±ν•œ 응닡
- sources: κ²€μƒ‰λœ λ¬Έμ„œ μ†ŒμŠ€ (리슀트)
- error: 였λ₯˜ λ°œμƒ μ‹œ 였λ₯˜ λ©”μ‹œμ§€
- details: 였λ₯˜ 상세 정보 (선택적)
"""
logger.info("μŒμ„± μ±— μš”μ²­ μˆ˜μ‹ ")
# μ˜€λ””μ˜€ 파일 확인
if 'audio' not in request.files:
logger.error("μ˜€λ””μ˜€ 파일이 μ œκ³΅λ˜μ§€ μ•ŠμŒ")
return jsonify({"error": "μ˜€λ””μ˜€ 파일이 μ œκ³΅λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."}), 400
audio_file = request.files['audio']
logger.info(f"μˆ˜μ‹ λœ 파일: {audio_file.filename}")
try:
# μ˜€λ””μ˜€ 파일 읽기
with audio_file.stream as f:
audio_bytes = f.read()
# μŒμ„±μΈμ‹ (VitoSTT)
stt = VitoSTT()
stt_result = stt.transcribe_audio(audio_bytes, language="ko")
if not stt_result["success"]:
logger.error(f"μŒμ„±μΈμ‹ μ‹€νŒ¨: {stt_result['error']}")
return jsonify({
"error": stt_result["error"],
"details": stt_result.get("details", "")
}), 500
transcription = stt_result["text"]
if not transcription:
logger.warning("μŒμ„±μΈμ‹ κ²°κ³Όκ°€ λΉ„μ–΄μžˆμŠ΅λ‹ˆλ‹€.")
return jsonify({"error": "μŒμ„±μ—μ„œ ν…μŠ€νŠΈλ₯Ό μΈμ‹ν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€."}), 400
logger.info(f"μŒμ„±μΈμ‹ 성곡: {transcription[:50]}...")
# 검색기 호좜: μΈμ‹λœ ν…μŠ€νŠΈλ₯Ό 쿼리둜 μ‚¬μš©
sources = retriever.search(transcription)
if not sources:
logger.warning("κ²€μƒ‰λœ μ†ŒμŠ€κ°€ μ—†μŠ΅λ‹ˆλ‹€.")
sources = []
# μ†ŒμŠ€ λ¬Έμ„œ λ‚΄μš©μ„ μ»¨ν…μŠ€νŠΈλ‘œ μ€€λΉ„
context = "\n".join([doc["text"] for doc in sources])
logger.info(f"κ²€μƒ‰λœ μ†ŒμŠ€ 수: {len(sources)}")
# LLM 호좜: 질문과 μ»¨ν…μŠ€νŠΈλ₯Ό λ°”νƒ•μœΌλ‘œ 응닡 생성
prompt = f"질문: {transcription}\n\nμ»¨ν…μŠ€νŠΈ:\n{context}\n\nλ‹΅λ³€:"
answer = llm_client.generate(prompt)
if not answer:
logger.error("LLM 응닡 생성 μ‹€νŒ¨")
return jsonify({"error": "응닡 생성에 μ‹€νŒ¨ν–ˆμŠ΅λ‹ˆλ‹€."}), 500
logger.info(f"LLM 응닡 생성 성곡: {answer[:50]}...")
# 응닡 λ°˜ν™˜
return jsonify({
"transcription": transcription,
"answer": answer,
"sources": sources # [{ "text": "...", "metadata": {...} }, ...]
})
except Exception as e:
logger.error(f"μŒμ„± μ±— 처리 쀑 였λ₯˜ λ°œμƒ: {str(e)}", exc_info=True)
return jsonify({
"error": "μŒμ„± 처리 쀑 λ‚΄λΆ€ 였λ₯˜ λ°œμƒ",
"details": str(e)
}), 500
@app.route('/api/upload', methods=['POST'])
def upload_document():
"""μ§€μ‹λ² μ΄μŠ€ λ¬Έμ„œ μ—…λ‘œλ“œ API"""
global base_retriever, retriever, app_ready
# μ•± μ€€λΉ„ μƒνƒœ 확인
if not app_ready:
return jsonify({"error": "앱이 아직 μ΄ˆκΈ°ν™” μ€‘μž…λ‹ˆλ‹€. μž μ‹œ ν›„ λ‹€μ‹œ μ‹œλ„ν•΄μ£Όμ„Έμš”."}), 503
try:
# 파일이 μš”μ²­μ— ν¬ν•¨λ˜μ–΄ μžˆλŠ”μ§€ 확인
if 'document' not in request.files:
return jsonify({"error": "λ¬Έμ„œ 파일이 μ œκ³΅λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."}), 400
doc_file = request.files['document']
logger.info(f"받은 파일λͺ…: {doc_file.filename}")
# 파일λͺ…이 λΉ„μ–΄μžˆλŠ”μ§€ 확인
if doc_file.filename == '':
return jsonify({"error": "μ„ νƒλœ 파일이 μ—†μŠ΅λ‹ˆλ‹€."}), 400
# 파일 ν˜•μ‹ 확인
if not allowed_doc_file(doc_file.filename):
logger.error(f"ν—ˆμš©λ˜μ§€ μ•ŠλŠ” 파일 ν˜•μ‹: {doc_file.filename}")
return jsonify({"error": "ν—ˆμš©λ˜μ§€ μ•ŠλŠ” 파일 ν˜•μ‹μž…λ‹ˆλ‹€. ν˜„μž¬ ν—ˆμš©λœ 파일 ν˜•μ‹: {}".format(', '.join(ALLOWED_DOC_EXTENSIONS))}), 400
# 파일λͺ… λ³΄μ•ˆ 처리
filename = secure_filename(doc_file.filename)
# 데이터 폴더에 μ €μž₯
filepath = os.path.join(app.config['DATA_FOLDER'], filename)
doc_file.save(filepath)
logger.info(f"λ¬Έμ„œκ°€ μ €μž₯λ˜μ—ˆμŠ΅λ‹ˆλ‹€: {filepath}")
# λ¬Έμ„œ 처리
try:
# λ¨Όμ € UTF-8둜 μ‹œλ„
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read()
except UnicodeDecodeError:
# UTF-8둜 μ‹€νŒ¨ν•˜λ©΄ CP949(ν•œκ΅­μ–΄ Windows κΈ°λ³Έ 인코딩)둜 μ‹œλ„
logger.info(f"UTF-8 λ””μ½”λ”© μ‹€νŒ¨, CP949둜 μ‹œλ„: {filename}")
with open(filepath, 'r', encoding='cp949') as f:
content = f.read()
# 메타데이터 생성
metadata = {
"source": filename,
"filename": filename,
"filetype": filename.rsplit('.', 1)[1].lower(),
"filepath": filepath
}
# λ¬Έμ„œ 청크 생성
docs = DocumentProcessor.text_to_documents(
content,
metadata=metadata,
chunk_size=512,
chunk_overlap=50
)
if docs:
logger.info(f"{len(docs)}개 λ¬Έμ„œ 청크λ₯Ό 검색기에 μΆ”κ°€ν•©λ‹ˆλ‹€...")
base_retriever.add_documents(docs)
# 인덱슀 μ €μž₯
logger.info(f"검색기 μƒνƒœλ₯Ό μ €μž₯ν•©λ‹ˆλ‹€...")
index_path = app.config['INDEX_PATH']
try:
base_retriever.save(index_path)
logger.info("인덱슀 μ €μž₯ μ™„λ£Œ")
except Exception as e:
logger.error(f"인덱슀 μ €μž₯ 쀑 였λ₯˜ λ°œμƒ: {e}")
return jsonify({"error": f"인덱슀 μ €μž₯ 쀑 였λ₯˜: {str(e)}"}), 500
return jsonify({
"success": True,
"message": f"파일 '{filename}'κ°€ μ„±κ³΅μ μœΌλ‘œ μ—…λ‘œλ“œλ˜κ³  {len(docs)}개 청크가 μΆ”κ°€λ˜μ—ˆμŠ΅λ‹ˆλ‹€."
})
else:
logger.warning(f"파일 '{filename}'μ—μ„œ μ²˜λ¦¬ν•  λ¬Έμ„œκ°€ μ—†μŠ΅λ‹ˆλ‹€.")
return jsonify({
"warning": True,
"message": f"파일 '{filename}'이 μ €μž₯λ˜μ—ˆμ§€λ§Œ μ²˜λ¦¬ν•  λ‚΄μš©μ΄ μ—†μŠ΅λ‹ˆλ‹€."
})
except Exception as e:
logger.error(f"λ¬Έμ„œ '{filename}' 처리 쀑 였λ₯˜ λ°œμƒ: {e}", exc_info=True)
return jsonify({"error": f"λ¬Έμ„œ 처리 쀑 였λ₯˜: {str(e)}"}), 500
except Exception as e:
logger.error(f"파일 μ—…λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}", exc_info=True)
return jsonify({"error": f"파일 μ—…λ‘œλ“œ 쀑 였λ₯˜: {str(e)}"}), 500
@app.route('/api/documents', methods=['GET'])
def list_documents():
"""μ§€μ‹λ² μ΄μŠ€ λ¬Έμ„œ λͺ©λ‘ API"""
global base_retriever, retriever, app_ready
# μ•± μ€€λΉ„ μƒνƒœ 확인
if not app_ready:
return jsonify({"error": "앱이 아직 μ΄ˆκΈ°ν™” μ€‘μž…λ‹ˆλ‹€. μž μ‹œ ν›„ λ‹€μ‹œ μ‹œλ„ν•΄μ£Όμ„Έμš”."}), 503
try:
# λ¬Έμ„œ μ†ŒμŠ€ λͺ©λ‘ 생성
sources = {}
if base_retriever and base_retriever.documents:
for doc in base_retriever.documents:
source = doc.get("source", "unknown")
if source in sources:
sources[source]["chunks"] += 1
else:
sources[source] = {
"filename": doc.get("filename", source),
"chunks": 1,
"filetype": doc.get("filetype", "unknown")
}
# λͺ©λ‘ ν˜•μ‹μœΌλ‘œ λ³€ν™˜
documents = []
for source, info in sources.items():
documents.append({
"source": source,
"filename": info["filename"],
"chunks": info["chunks"],
"filetype": info["filetype"]
})
# 청크 수둜 μ •λ ¬
documents.sort(key=lambda x: x["chunks"], reverse=True)
return jsonify({
"documents": documents,
"total_documents": len(documents),
"total_chunks": sum(doc["chunks"] for doc in documents)
})
except Exception as e:
logger.error(f"λ¬Έμ„œ λͺ©λ‘ 쑰회 쀑 였λ₯˜ λ°œμƒ: {e}", exc_info=True)
return jsonify({"error": f"λ¬Έμ„œ λͺ©λ‘ 쑰회 쀑 였λ₯˜: {str(e)}"}), 500
# 정적 파일 μ„œλΉ™
@app.route('/static/<path:path>')
def send_static(path):
return send_from_directory('static', path)
if __name__ == '__main__':
app.run(debug=False, host='0.0.0.0', port=5000)