# NOTE: Hugging Face Space "KB_AI_Challenge" — app.py (author: nneans, commit f438fbf, "Update app.py")
# =========================================================
# KB AI Challenge - Professional RAG System (Multilingual)
# =========================================================
import os
import sys
import numpy as np
import traceback
import fitz # PyMuPDF
from typing import List
# --- Library imports ---
import gradio as gr
import speech_recognition as sr
from dotenv import load_dotenv
# Load .env so GROQ_API_KEY can come from a local env file
load_dotenv()
from deep_translator import GoogleTranslator
from sentence_transformers import SentenceTransformer
from groq import Groq
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
try:
from langchain.text_splitter import RecursiveCharacterTextSplitter
except ImportError:
from langchain_text_splitters import RecursiveCharacterTextSplitter
# =========================================================
# 1. Configuration & initialisation
# =========================================================
# Read the Groq key from the environment; the placeholder default keeps the
# app booting without a key (LLM calls are disabled further below).
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "your_groq_api_key_here")
# Korean sentence-embedding model id for SentenceTransformer.
EMBEDDING_MODEL_NAME = "jhgan/ko-sroberta-multitask"
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
COLLECTION_NAME = "local_kb"
print("πŸ› οΈ μ‹œμŠ€ν…œ μ΄ˆκΈ°ν™” 쀑... (System Init)")
# Load the embedding model (truncate inputs at 512 tokens).
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
embedding_model.max_seq_length = 512
# Qdrant vector store, in-memory only — the knowledge base is lost on restart.
qdrant_client = QdrantClient(":memory:")
try:
    # size=768 must match the dimensionality of the embedding model's output
    # (vectors produced by embedding_model.encode are stored unchecked).
    qdrant_client.recreate_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=768, distance=Distance.COSINE),
    )
    print(f"βœ… Qdrant Collection Ready.")
except Exception as e:
    print(f"❌ Qdrant Error: {e}")
# Groq client init — only attempted when a real-looking key is present.
groq_client = None
if GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here":
    try:
        groq_client = Groq(api_key=GROQ_API_KEY)
    except Exception as e:
        print(f"❌ Groq Error: {e}")
else:
    print("⚠️ Groq API Key Missing.")
# Monotonic point id for Qdrant upserts (incremented once per stored chunk).
doc_id_counter = 0
print("βœ… System Ready.")
# =========================================================
# 2. Multilingual support (Translation & STT)
# =========================================================
# Maps each UI dropdown label to:
#   "code": language code used with GoogleTranslator / internal pivot checks
#   "stt":  locale string passed to Google Speech Recognition
LANG_MAP = {
    "ν•œκ΅­μ–΄ (Korean)": {"code": "ko", "stt": "ko-KR"},
    "English (μ˜μ–΄)": {"code": "en", "stt": "en-US"},
    "ζ—₯本θͺž (Japanese)": {"code": "ja", "stt": "ja-JP"},
    "δΈ­ζ–‡ (Chinese)": {"code": "zh-CN", "stt": "zh-CN"}
}
def translate_text(text, target_lang_code):
    """Translate *text* (auto-detected source) into *target_lang_code*.

    Korean ("ko") is the system's pivot language, so it is returned
    unchanged. On any translation failure the original text is returned
    as a best-effort fallback instead of raising.
    """
    # Korean is the internal pivot language — nothing to translate.
    if target_lang_code == "ko":
        return text
    try:
        return GoogleTranslator(source='auto', target=target_lang_code).translate(text)
    except Exception:
        # Narrowed from a bare `except:` (which would also swallow
        # SystemExit/KeyboardInterrupt); keep the best-effort fallback.
        return text
def translate_to_korean(text):
    """Translate *text* (auto-detected source) into Korean.

    Used to normalise all user queries to the knowledge base's pivot
    language. Falls back to returning the input unchanged on any
    translation failure rather than raising.
    """
    try:
        return GoogleTranslator(source='auto', target='ko').translate(text)
    except Exception:
        # Narrowed from a bare `except:`; still fail soft to the original text.
        return text
# =========================================================
# 3. Core logic (RAG Pipeline)
# =========================================================
def process_uploaded_files(files):
    """Extract text from uploaded PDFs, chunk, embed and store in Qdrant.

    Args:
        files: list of Gradio file objects (with a ``.name`` path) or
            plain path strings; may be None/empty.

    Returns:
        A human-readable status string summarising how many chunks were
        stored per file (user-facing text is intentionally Korean).
    """
    global doc_id_counter
    if not files:
        return "파일이 μ„ νƒλ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."
    total_chunks = 0
    status_msg = ""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50, length_function=len)
    for file in files:
        try:
            # Gradio may hand us a tempfile wrapper or a bare path string.
            file_path = file.name if hasattr(file, 'name') else file
            # Context manager guarantees the PDF handle is closed even on
            # mid-loop errors (the original leaked the fitz.Document).
            with fitz.open(file_path) as doc:
                file_text = "".join(page.get_text() for page in doc)
            if not file_text.strip():
                status_msg += f"⚠️ {os.path.basename(file_path)}: ν…μŠ€νŠΈ μ—†μŒ.\n"
                continue
            chunks = text_splitter.split_text(file_text)
            points = []
            for chunk in chunks:
                vector = embedding_model.encode(chunk).tolist()
                payload = {"filename": os.path.basename(file_path), "text": chunk}
                points.append(PointStruct(id=doc_id_counter, vector=vector, payload=payload))
                doc_id_counter += 1
            if points:
                qdrant_client.upsert(collection_name=COLLECTION_NAME, points=points)
                total_chunks += len(points)
                status_msg += f"βœ… {os.path.basename(file_path)} ({len(points)} 개 μ €μž₯됨)\n"
        except Exception as e:
            # Best-effort per-file: one broken PDF must not abort the batch.
            status_msg += f"❌ 였λ₯˜: {os.path.basename(file_path)} - {str(e)}\n"
    return f"총 {total_chunks}개 데이터 처리 μ™„λ£Œ.\n\n{status_msg}"
def search_knowledge_base(query, top_k=5):
    """Embed *query* and return the *top_k* most similar stored chunks.

    Returns a list of Qdrant scored points (payload included), or an
    empty list when the search fails for any reason so callers can treat
    "error" and "no hits" uniformly.
    """
    try:
        query_vector = embedding_model.encode(query).tolist()
        res = qdrant_client.query_points(
            collection_name=COLLECTION_NAME, query=query_vector, limit=top_k, with_payload=True
        )
        return res.points
    except Exception:
        # Narrowed from a bare `except:`; still fail soft with no hits.
        return []
def generate_answer_groq(query, context_text):
    """Ask the Groq LLM to answer *query* grounded in *context_text*.

    Returns the model's answer text, or a Korean status/error string
    when no client is configured or the API call fails.
    """
    # Without a configured client we cannot call the API at all.
    if not groq_client:
        return "API ν‚€κ°€ ν•„μš”ν•©λ‹ˆλ‹€."
    system_prompt = """
당신은 KB 금육그룹의 μ „λ¬Έ AI μ–΄μ‹œμŠ€ν„΄νŠΈμž…λ‹ˆλ‹€.
제곡된 [λ¬Έλ§₯]에 κΈ°λ°˜ν•˜μ—¬ μ§ˆλ¬Έμ— λŒ€ν•΄ μ •ν™•ν•˜κ³  전문적인 닡변을 μž‘μ„±ν•˜μ„Έμš”.
λͺ¨λ₯΄λŠ” λ‚΄μš©μ€ λͺ¨λ₯Έλ‹€κ³  λ‹΅ν•˜κ³ , μΆ”μΈ‘ν•˜μ§€ λ§ˆμ„Έμš”.
닡변은 ν•œκ΅­μ–΄λ‘œ μž‘μ„±ν•˜μ„Έμš”.
"""
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"질문: {query}\n\n[λ¬Έλ§₯]\n{context_text}"},
    ]
    try:
        completion = groq_client.chat.completions.create(
            messages=chat_messages,
            model=GROQ_MODEL_NAME,
            temperature=0.1,
        )
    except Exception as exc:
        return f"응닡 생성 였λ₯˜: {exc}"
    return completion.choices[0].message.content
def run_rag_chat(message, history, lang_selection):
    """One full RAG turn: translate query to Korean, retrieve, answer, translate back.

    Returns ("", updated_history, reference_text) where history uses the
    Gradio messages format ({"role": ..., "content": ...} dicts).
    """
    if not message:
        return "", history, ""
    target_lang = LANG_MAP[lang_selection]["code"]

    # 1. Normalise the query to Korean (the knowledge base's pivot language).
    korean_query = translate_to_korean(message) if target_lang != "ko" else message

    # 2. Retrieve context and generate the Korean answer.
    hits = search_knowledge_base(korean_query)
    if hits:
        context_text = "\n\n".join(h.payload['text'] for h in hits)
        # Group hit scores per source file for the reference panel.
        ref_data = {}
        for h in hits:
            ref_data.setdefault(h.payload['filename'], []).append(h.score)
        reference_text = "\n".join(
            f"- {fname} (κ΄€λ ¨ λ‚΄μš© {len(scores)}건, 졜고 μœ μ‚¬λ„: {max(scores):.2f})"
            for fname, scores in ref_data.items()
        )
        bot_response_ko = generate_answer_groq(korean_query, context_text)
    else:
        bot_response_ko = "μ£„μ†‘ν•©λ‹ˆλ‹€. κ΄€λ ¨ 정보λ₯Ό 찾을 수 μ—†μŠ΅λ‹ˆλ‹€."
        reference_text = "μ°Έκ³  λ¬Έμ„œ μ—†μŒ"

    # 3. Translate the answer back to the user's language, keeping the
    #    Korean original attached for reference.
    if target_lang != "ko":
        translated = translate_text(bot_response_ko, target_lang)
        final_response = f"{translated}\n\n---\n[ν•œκ΅­μ–΄ 원문]\n{bot_response_ko}"
    else:
        final_response = bot_response_ko

    # Append the turn in messages format (Gradio 6.x Chatbot).
    new_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": final_response},
    ]
    return "", new_history, reference_text
def voice_to_text_chat(audio, history, lang_selection):
    """Transcribe a Gradio microphone recording and feed it into the RAG chat.

    *audio* is a (sample_rate, numpy_array) tuple or None. Returns the
    same (msg, history, ref_text) triple as run_rag_chat; on STT failure
    the history is left untouched and an error string goes to the
    reference panel.
    """
    if audio is None:
        return "", history, "μŒμ„± μž…λ ₯ μ—†μŒ"
    stt_lang = LANG_MAP[lang_selection]["stt"]
    try:
        sample_rate, samples = audio
        # Normalise to 16-bit PCM, mono — the layout AudioData is given
        # (sample width 2 bytes below).
        if samples.dtype == np.float32:
            samples = (samples * 32767).astype(np.int16)
        if len(samples.shape) > 1:
            samples = samples.mean(axis=1).astype(np.int16)
        pcm = sr.AudioData(samples.tobytes(), sample_rate, 2)
        recognizer = sr.Recognizer()
        # Recognise speech in the user's selected language.
        transcript = recognizer.recognize_google(pcm, language=stt_lang)
        # Hand the transcript to the normal chat pipeline.
        return run_rag_chat(transcript, history, lang_selection)
    except sr.UnknownValueError:
        return "", history, "μŒμ„±μ„ 이해할 수 μ—†μŠ΅λ‹ˆλ‹€."
    except Exception as e:
        return "", history, f"였λ₯˜: {e}"
# =========================================================
# 4. UI Layout (Clean Professional Korean)
# =========================================================
# Soft theme with amber accents and a Korean web font.
theme = gr.themes.Soft(
    primary_hue="amber",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Noto Sans KR"), "sans-serif"]
)
# Hide the Gradio footer and collapse unused vertical space.
css = """
footer {visibility: hidden !important;}
.gradio-container {min-height: 0px !important;}
"""
with gr.Blocks(theme=theme, title="KB AI Challenge", css=css) as demo:
    with gr.Row():
        # --- LEFT SIDEBAR ---
        with gr.Column(scale=1, min_width=300, variant="panel"):
            gr.Markdown("## KB AI Challenge")
            gr.Markdown("**λ‹€κ΅­μ–΄ 금육 AI μ–΄μ‹œμŠ€ν„΄νŠΈ**")
            with gr.Group():
                # Language selector; choices must stay in sync with LANG_MAP.
                lang_dropdown = gr.Dropdown(
                    choices=list(LANG_MAP.keys()),
                    value="ν•œκ΅­μ–΄ (Korean)",
                    label="μ–Έμ–΄ μ„€μ •",
                    interactive=True
                )
            file_input = gr.File(label="지식 베이슀 (PDF)", file_count="multiple", file_types=[".pdf"])
            with gr.Row():
                upload_btn = gr.Button("μ—…λ‘œλ“œ 및 뢄석", variant="primary", size="sm")
                upload_status = gr.Textbox(show_label=False, placeholder="μƒνƒœ λŒ€κΈ° 쀑...", interactive=False, lines=1, max_lines=1)
            gr.Markdown("### μŒμ„± λŒ€ν™”")
            # Microphone input; type="numpy" delivers (sample_rate, ndarray)
            # as expected by voice_to_text_chat.
            audio_input = gr.Audio(sources=["microphone"], type="numpy", label="μŒμ„± μž…λ ₯", show_label=False)
            with gr.Accordion("μ‹œμŠ€ν…œ μ•„ν‚€ν…μ²˜", open=False):
                gr.Markdown(
                    """
                    **μ΅œμ ν™” λ‚΄μ—­**
                    1. **STT**: Google Speech API
                    2. **λ²ˆμ—­**: Google Translate API
                    3. **LLM**: Groq LPU (Llama 3)
                    """
                )
        # --- RIGHT MAIN ---
        with gr.Column(scale=3):
            # chatbot (Messages format)
            chatbot = gr.Chatbot(label="λŒ€ν™”", height=500, show_label=False)
            # References panel, fed by run_rag_chat's third return value.
            gr.Markdown("**μ°Έκ³  λ¬Έμ„œ**")
            ref_output = gr.Textbox(show_label=False, interactive=False, lines=3, max_lines=5, placeholder="κ΄€λ ¨ λ¬Έμ„œκ°€ ν‘œμ‹œλ©λ‹ˆλ‹€.")
            # Input Area
            with gr.Row():
                msg = gr.Textbox(
                    scale=6,
                    show_label=False,
                    placeholder="μ§ˆλ¬Έμ„ μž…λ ₯ν•˜μ„Έμš”...",
                    container=False
                )
                submit_btn = gr.Button("전솑", scale=1, variant="primary")
    # --- Event Handlers ---
    upload_btn.click(process_uploaded_files, inputs=[file_input], outputs=[upload_status])
    # Enter key and button both trigger the same RAG turn.
    msg.submit(run_rag_chat, [msg, chatbot, lang_dropdown], [msg, chatbot, ref_output])
    submit_btn.click(run_rag_chat, [msg, chatbot, lang_dropdown], [msg, chatbot, ref_output])
    # Voice path runs STT first, then the same pipeline.
    audio_input.stop_recording(voice_to_text_chat, [audio_input, chatbot, lang_dropdown], [msg, chatbot, ref_output])
if __name__ == "__main__":
    demo.launch(share=True)