|
|
"""
RAG search chatbot web application (Flask).
"""
|
|
|
|
|
import os |
|
|
import json |
|
|
import logging |
|
|
import tempfile |
|
|
import threading |
|
|
from flask import Flask, request, jsonify, render_template, send_from_directory |
|
|
from werkzeug.utils import secure_filename |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Configure root logging once for the whole application.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

# Load environment variables from a local .env file before the client
# modules below are imported — presumably they read API keys at import
# time (TODO confirm against utils.llm_client / utils.vito_stt).
load_dotenv()
|
|
|
|
|
|
|
|
from utils.vito_stt import VitoSTT |
|
|
from utils.llm_client import DeepSeekLLM |
|
|
from utils.document_processor import DocumentProcessor |
|
|
from retrieval.vector_retriever import VectorRetriever |
|
|
from retrieval.reranker import ReRanker |
|
|
|
|
|
|
|
|
app = Flask(__name__)

# Cap request bodies (uploads) at 10 MB.
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024
# uploads/ next to this file; data/ and data/index/ one level up.
app.config['UPLOAD_FOLDER'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
app.config['DATA_FOLDER'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
app.config['INDEX_PATH'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data', 'index')

# Ensure working directories exist before any request arrives.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['DATA_FOLDER'], exist_ok=True)
os.makedirs(app.config['INDEX_PATH'], exist_ok=True)

# Accepted upload extensions for the voice endpoint.
ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'ogg', 'm4a'}

# Accepted upload extensions for the document-upload endpoint.
ALLOWED_DOC_EXTENSIONS = {'txt', 'md', 'pdf', 'docx', 'csv'}

# Shared service clients, created eagerly at import time.
llm_client = DeepSeekLLM()

stt_client = VitoSTT()

# Retrieval state: populated by background_init() / init_retriever().
base_retriever = None

retriever = None

# Flipped to True once background initialization completes.
app_ready = False
|
|
|
|
|
def allowed_audio_file(filename, allowed_extensions=None):
    """Return True if *filename* has an allowed audio extension.

    Args:
        filename: Name of the uploaded file.
        allowed_extensions: Optional set of lowercase extensions to check
            against; defaults to the module-level ALLOWED_AUDIO_EXTENSIONS.

    Returns:
        bool: True when the file has an extension in the allowed set.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_AUDIO_EXTENSIONS
    # Extension is whatever follows the last dot, compared case-insensitively.
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
|
|
|
|
|
def allowed_doc_file(filename, allowed_extensions=None):
    """Return True if *filename* has an allowed document extension.

    Args:
        filename: Name of the uploaded file.
        allowed_extensions: Optional set of lowercase extensions to check
            against; defaults to the module-level ALLOWED_DOC_EXTENSIONS.

    Returns:
        bool: True when the file has an extension in the allowed set.
    """
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_DOC_EXTENSIONS
    # Extension is whatever follows the last dot, compared case-insensitively.
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in allowed_extensions
|
|
|
|
|
def init_retriever():
    """Build or restore the retrieval stack.

    Loads a persisted vector index when one exists on disk; otherwise
    creates a fresh VectorRetriever, populates it from the data folder,
    persists it, and finally wraps it in a re-ranking retriever.

    Returns:
        ReRanker: the re-ranking retriever wrapping the base retriever.
    """
    global base_retriever, retriever

    index_path = app.config['INDEX_PATH']

    # Prefer a previously persisted index; fall back to a fresh retriever
    # on any load failure or when no index file is present.
    has_saved_index = os.path.exists(os.path.join(index_path, "documents.json"))
    if has_saved_index:
        try:
            logger.info(f"κΈ°μ‘΄ λ²‘ν° μΈλ±μ€λ₯Ό '{index_path}'μμ λ‘λν©λλ€...")
            base_retriever = VectorRetriever.load(index_path)
            logger.info(f"{len(base_retriever.documents)}κ° λ¬Έμκ° λ‘λλμμ΅λλ€.")
        except Exception as e:
            logger.error(f"μΈλ±μ€ λ‘λ μ€ μ€λ₯ λ°μ: {e}")
            logger.info("μ κ²μκΈ°λ₯Ό μ΄κΈ°νν©λλ€...")
            base_retriever = VectorRetriever()
    else:
        logger.info("κΈ°μ‘΄ μΈλ±μ€λ₯Ό μ°Ύμ μ μμ΄ μ κ²μκΈ°λ₯Ό μ΄κΈ°νν©λλ€...")
        base_retriever = VectorRetriever()

    # Populate an empty index from the raw files in the data folder.
    data_path = app.config['DATA_FOLDER']
    if not base_retriever.documents and os.path.exists(data_path):
        logger.info(f"{data_path}μμ λ¬Έμλ₯Ό λ‘λν©λλ€...")
        loaded_docs = DocumentProcessor.load_documents_from_directory(
            data_path,
            extensions=[".txt", ".md", ".csv"],
            recursive=True
        )
        if loaded_docs:
            logger.info(f"{len(loaded_docs)}κ° λ¬Έμλ₯Ό κ²μκΈ°μ μΆκ°ν©λλ€...")
            base_retriever.add_documents(loaded_docs)

    # Persist the (possibly updated) index so the next startup is fast.
    logger.info(f"κ²μκΈ° μνλ₯Ό '{index_path}'μ μ μ₯ν©λλ€...")
    try:
        base_retriever.save(index_path)
        logger.info("μΈλ±μ€ μ μ₯ μλ£")
    except Exception as e:
        logger.error(f"μΈλ±μ€ μ μ₯ μ€ μ€λ₯ λ°μ: {e}")

    logger.info("μ¬μμν κ²μκΈ°λ₯Ό μ΄κΈ°νν©λλ€...")

    def custom_rerank_fn(query, results):
        """Lexical re-ranker: blend the retriever score (70%) with a
        length-normalized query-term frequency score (30%)."""
        query_terms = set(query.lower().split())

        for hit in results:
            if "text" in hit:
                body = hit["text"].lower()
                # Number of distinct query terms appearing in the chunk.
                term_freq = sum(1 for term in query_terms if term in body)
                # Normalize by chunk word count; +1 guards empty text.
                normalized_score = term_freq / (len(body.split()) + 1) * 10
                hit["rerank_score"] = hit.get("score", 0) * 0.7 + normalized_score * 0.3
            else:
                # No text field: fall back to the raw retriever score.
                hit["rerank_score"] = hit.get("score", 0)

        # In-place, stable sort by the blended score, best first.
        results.sort(key=lambda r: r.get("rerank_score", 0), reverse=True)
        return results

    retriever = ReRanker(
        base_retriever=base_retriever,
        rerank_fn=custom_rerank_fn,
        rerank_field="text"
    )

    logger.info("μ¬μμν κ²μκΈ° μ΄κΈ°ν μλ£")

    return retriever
|
|
|
|
|
|
|
|
def background_init():
    """Initialize the retriever on a worker thread, flipping app_ready."""
    global app_ready, retriever

    logger.info("λ°±κ·ΈλΌμ΄λ μ΄κΈ°ν μμ")
    try:
        retriever = init_retriever()
    except Exception as e:
        # Keep serving the loading page; the error is visible in the logs.
        logger.error(f"μ± μ΄κΈ°ν μ€ μ€λ₯ λ°μ: {e}", exc_info=True)
        app_ready = False
    else:
        app_ready = True
        logger.info("μ± μ΄κΈ°ν μλ£")
|
|
|
|
|
|
|
|
# Run retriever initialization on a daemon thread so the web server can
# start serving the loading page immediately; daemon=True lets the
# process exit even if initialization is still running.
init_thread = threading.Thread(target=background_init)
init_thread.daemon = True
init_thread.start()
|
|
|
|
|
@app.route('/')
def index():
    """Main page: serve a loading screen until background init finishes."""
    template = 'index.html' if app_ready else 'loading.html'
    return render_template(template)
|
|
|
|
|
@app.route('/api/status')
def app_status():
    """Report whether background initialization has completed."""
    return jsonify(ready=app_ready)
|
|
|
|
|
@app.route('/api/chat', methods=['POST'])
def chat():
    """Text-based chatbot API.

    Request JSON: {"query": str}.
    Response JSON: {"answer": str, "sources": [{"source", "score"}, ...]}
    on success, or {"error": str} with status 400/500/503.
    """
    global retriever, app_ready

    # Return 503 while the background initializer is still running.
    if not app_ready:
        return jsonify({"error": "μ±μ΄ μμ§ μ΄κΈ°ν μ€μ
λλ€. μ μ ν λ€μ μλν΄μ£ΌμΈμ."}), 503

    try:
        data = request.get_json()
        if not data or 'query' not in data:
            return jsonify({"error": "μΏΌλ¦¬κ° μ 곡λμ§ μμμ΅λλ€."}), 400

        query = data['query']
        logger.info(f"쿼리 μμ : {query}")

        # Two-stage retrieval: 20 first-stage candidates, reranked top 5.
        search_results = retriever.search(query, top_k=5, first_stage_k=20)

        # Concatenate retrieved chunks into one context string for the LLM.
        context = DocumentProcessor.prepare_rag_context(search_results, field="text")

        if not context:
            logger.warning("κ²μ κ²°κ³Όκ° μμ΅λλ€.")
            return jsonify({
                "answer": "μ£μ‘ν©λλ€. κ΄λ ¨ μ 보λ₯Ό μ°Ύμ μ μμ΅λλ€.",
                "sources": []
            })

        # Generate an answer grounded in the retrieved context.
        answer = llm_client.rag_generate(query, context)

        # Build source attribution; prefer rerank_score when present.
        sources = []
        for result in search_results:
            if "source" in result:
                source_info = {
                    "source": result.get("source", "Unknown"),
                    "score": result.get("rerank_score", result.get("score", 0))
                }
                sources.append(source_info)

        return jsonify({
            "answer": answer,
            "sources": sources
        })

    except Exception as e:
        logger.error(f"μ±ν
μ²λ¦¬ μ€ μ€λ₯ λ°μ: {e}", exc_info=True)
        return jsonify({"error": f"μ²λ¦¬ μ€ μ€λ₯κ° λ°μνμ΅λλ€: {str(e)}"}), 500
|
|
|
|
|
@app.route('/api/voice', methods=['POST'])
def voice_chat():
    """Voice chat API endpoint.

    Accepts an uploaded audio file (form field 'audio'), transcribes it
    with the STT service, retrieves relevant documents, and generates an
    LLM answer.

    Returns JSON with:
        - transcription: recognized text
        - answer: LLM-generated response
        - sources: retrieved source documents (list)
        - error: error message on failure
        - details: error detail (optional)

    Fixes vs. previous version: reuses the module-level stt_client instead
    of constructing a new VitoSTT per request; adds the app_ready guard
    that the other endpoints have (retriever is None until init finishes);
    validates the audio extension with the otherwise-unused
    allowed_audio_file helper.
    """
    logger.info("μμ± μ± μμ² μμ ")

    # Consistent with /api/chat and /api/upload: refuse work until the
    # background initializer has finished, otherwise retriever is None.
    if not app_ready:
        return jsonify({"error": "μ±μ΄ μμ§ μ΄κΈ°ν μ€μλλ€. μ μ ν λ€μ μλν΄μ£ΌμΈμ."}), 503

    if 'audio' not in request.files:
        logger.error("μ€λμ€ νμΌμ΄ μ 곡λμ§ μμ")
        return jsonify({"error": "μ€λμ€ νμΌμ΄ μ 곡λμ§ μμμ΅λλ€."}), 400

    audio_file = request.files['audio']
    logger.info(f"μμ λ νμΌ: {audio_file.filename}")

    # Reject unsupported audio formats early (helper was never applied before).
    if audio_file.filename and not allowed_audio_file(audio_file.filename):
        logger.error(f"νμ©λμ§ μλ νμΌ νμ: {audio_file.filename}")
        return jsonify({"error": "νμ©λμ§ μλ νμΌ νμμλλ€. νμ¬ νμ©λ νμΌ νμ: {}".format(', '.join(ALLOWED_AUDIO_EXTENSIONS))}), 400

    try:
        # Read the uploaded bytes directly; FileStorage.read() handles the
        # underlying stream, no manual context management needed.
        audio_bytes = audio_file.read()

        # Reuse the shared STT client instead of re-creating one per request.
        stt_result = stt_client.transcribe_audio(audio_bytes, language="ko")

        if not stt_result["success"]:
            logger.error(f"μμ±μΈμ μ€ν¨: {stt_result['error']}")
            return jsonify({
                "error": stt_result["error"],
                "details": stt_result.get("details", "")
            }), 500

        transcription = stt_result["text"]
        if not transcription:
            logger.warning("μμ±μΈμ κ²°κ³Όκ° λΉμ΄μμ΅λλ€.")
            return jsonify({"error": "μμ±μμ νμ€νΈλ₯Ό μΈμνμ§ λͺ»νμ΅λλ€."}), 400

        logger.info(f"μμ±μΈμ μ±κ³΅: {transcription[:50]}...")

        # Retrieve supporting documents for the transcribed question.
        sources = retriever.search(transcription)
        if not sources:
            logger.warning("κ²μλ μμ€κ° μμ΅λλ€.")
            sources = []

        # Concatenate chunk texts into the prompt context.
        context = "\n".join([doc["text"] for doc in sources])
        logger.info(f"κ²μλ μμ€ μ: {len(sources)}")

        prompt = f"μ§λ¬Έ: {transcription}\n\n컨νμ€νΈ:\n{context}\n\nλ΅λ³:"
        answer = llm_client.generate(prompt)

        if not answer:
            logger.error("LLM μλ΅ μμ± μ€ν¨")
            return jsonify({"error": "μλ΅ μμ±μ μ€ν¨νμ΅λλ€."}), 500

        logger.info(f"LLM μλ΅ μμ± μ±κ³΅: {answer[:50]}...")

        return jsonify({
            "transcription": transcription,
            "answer": answer,
            "sources": sources
        })

    except Exception as e:
        logger.error(f"μμ± μ± μ²λ¦¬ μ€ μ€λ₯ λ°μ: {str(e)}", exc_info=True)
        return jsonify({
            "error": "μμ± μ²λ¦¬ μ€ λ΄λΆ μ€λ₯ λ°μ",
            "details": str(e)
        }), 500
|
|
|
|
|
@app.route('/api/upload', methods=['POST'])
def upload_document():
    """Knowledge-base document upload API.

    Accepts a 'document' form-data file, saves it into the data folder,
    chunks the text, adds the chunks to the retriever, and persists the
    index.

    Fix: log and response f-strings previously contained a literal
    "(unknown)" placeholder where the filename should be interpolated.
    """
    global base_retriever, retriever, app_ready

    # Return 503 while the background initializer is still running.
    if not app_ready:
        return jsonify({"error": "μ±μ΄ μμ§ μ΄κΈ°ν μ€μλλ€. μ μ ν λ€μ μλν΄μ£ΌμΈμ."}), 503

    try:
        if 'document' not in request.files:
            return jsonify({"error": "λ¬Έμ νμΌμ΄ μ 곡λμ§ μμμ΅λλ€."}), 400

        doc_file = request.files['document']
        logger.info(f"λ°μ νμΌλͺ: {doc_file.filename}")

        if doc_file.filename == '':
            return jsonify({"error": "μ νλ νμΌμ΄ μμ΅λλ€."}), 400

        if not allowed_doc_file(doc_file.filename):
            logger.error(f"νμ©λμ§ μλ νμΌ νμ: {doc_file.filename}")
            return jsonify({"error": "νμ©λμ§ μλ νμΌ νμμλλ€. νμ¬ νμ©λ νμΌ νμ: {}".format(', '.join(ALLOWED_DOC_EXTENSIONS))}), 400

        # Sanitize the user-supplied filename before writing to disk.
        filename = secure_filename(doc_file.filename)
        filepath = os.path.join(app.config['DATA_FOLDER'], filename)
        doc_file.save(filepath)
        logger.info(f"λ¬Έμκ° μ μ₯λμμ΅λλ€: {filepath}")

        try:
            # Read as UTF-8 first, falling back to CP949 for legacy Korean
            # encodings. NOTE(review): binary formats (pdf, docx) are in
            # ALLOWED_DOC_EXTENSIONS but will fail this text read and be
            # reported as a processing error — confirm intended behavior.
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    content = f.read()
            except UnicodeDecodeError:
                logger.info(f"UTF-8 λμ½λ© μ€ν¨, CP949λ‘ μλ: {filename}")
                with open(filepath, 'r', encoding='cp949') as f:
                    content = f.read()

            # Metadata attached to every chunk of this document.
            metadata = {
                "source": filename,
                "filename": filename,
                "filetype": filename.rsplit('.', 1)[1].lower(),
                "filepath": filepath
            }

            # Split into overlapping chunks for retrieval.
            docs = DocumentProcessor.text_to_documents(
                content,
                metadata=metadata,
                chunk_size=512,
                chunk_overlap=50
            )

            if docs:
                logger.info(f"{len(docs)}κ° λ¬Έμ μ²ν¬λ₯Ό κ²μκΈ°μ μΆκ°ν©λλ€...")
                base_retriever.add_documents(docs)

                # Persist the updated index so it survives restarts.
                logger.info(f"κ²μκΈ° μνλ₯Ό μ μ₯ν©λλ€...")
                index_path = app.config['INDEX_PATH']
                try:
                    base_retriever.save(index_path)
                    logger.info("μΈλ±μ€ μ μ₯ μλ£")
                except Exception as e:
                    logger.error(f"μΈλ±μ€ μ μ₯ μ€ μ€λ₯ λ°μ: {e}")
                    return jsonify({"error": f"μΈλ±μ€ μ μ₯ μ€ μ€λ₯: {str(e)}"}), 500

                return jsonify({
                    "success": True,
                    "message": f"νμΌ '{filename}'κ° μ±κ³΅μ μΌλ‘ μλ‘λλκ³ {len(docs)}κ° μ²ν¬κ° μΆκ°λμμ΅λλ€."
                })
            else:
                logger.warning(f"νμΌ '{filename}'μμ μ²λ¦¬ν λ¬Έμκ° μμ΅λλ€.")
                return jsonify({
                    "warning": True,
                    "message": f"νμΌ '{filename}'μ΄ μ μ₯λμμ§λ§ μ²λ¦¬ν λ΄μ©μ΄ μμ΅λλ€."
                })

        except Exception as e:
            logger.error(f"λ¬Έμ '{filename}' μ²λ¦¬ μ€ μ€λ₯ λ°μ: {e}", exc_info=True)
            return jsonify({"error": f"λ¬Έμ μ²λ¦¬ μ€ μ€λ₯: {str(e)}"}), 500

    except Exception as e:
        logger.error(f"νμΌ μλ‘λ μ€ μ€λ₯ λ°μ: {e}", exc_info=True)
        return jsonify({"error": f"νμΌ μλ‘λ μ€ μ€λ₯: {str(e)}"}), 500
|
|
|
|
|
@app.route('/api/documents', methods=['GET']) |
|
|
def list_documents(): |
|
|
"""μ§μλ² μ΄μ€ λ¬Έμ λͺ©λ‘ API""" |
|
|
global base_retriever, retriever, app_ready |
|
|
|
|
|
|
|
|
if not app_ready: |
|
|
return jsonify({"error": "μ±μ΄ μμ§ μ΄κΈ°ν μ€μ
λλ€. μ μ ν λ€μ μλν΄μ£ΌμΈμ."}), 503 |
|
|
|
|
|
try: |
|
|
|
|
|
sources = {} |
|
|
|
|
|
if base_retriever and base_retriever.documents: |
|
|
for doc in base_retriever.documents: |
|
|
source = doc.get("source", "unknown") |
|
|
if source in sources: |
|
|
sources[source]["chunks"] += 1 |
|
|
else: |
|
|
sources[source] = { |
|
|
"filename": doc.get("filename", source), |
|
|
"chunks": 1, |
|
|
"filetype": doc.get("filetype", "unknown") |
|
|
} |
|
|
|
|
|
|
|
|
documents = [] |
|
|
for source, info in sources.items(): |
|
|
documents.append({ |
|
|
"source": source, |
|
|
"filename": info["filename"], |
|
|
"chunks": info["chunks"], |
|
|
"filetype": info["filetype"] |
|
|
}) |
|
|
|
|
|
|
|
|
documents.sort(key=lambda x: x["chunks"], reverse=True) |
|
|
|
|
|
return jsonify({ |
|
|
"documents": documents, |
|
|
"total_documents": len(documents), |
|
|
"total_chunks": sum(doc["chunks"] for doc in documents) |
|
|
}) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"λ¬Έμ λͺ©λ‘ μ‘°ν μ€ μ€λ₯ λ°μ: {e}", exc_info=True) |
|
|
return jsonify({"error": f"λ¬Έμ λͺ©λ‘ μ‘°ν μ€ μ€λ₯: {str(e)}"}), 500 |
|
|
|
|
|
|
|
|
@app.route('/static/<path:path>')
def send_static(path):
    """Serve static assets from the app's static directory."""
    return send_from_directory('static', path)
|
|
|
|
|
if __name__ == '__main__':
    # Development entry point: bind on all interfaces, port 5000,
    # debug disabled. Production should use a WSGI server instead.
    app.run(debug=False, host='0.0.0.0', port=5000)
|
|
|