# NOTE(review): removed non-code residue ("Spaces: Sleeping") captured from the
# Hugging Face Spaces status page during extraction — it is not part of the program.
# =========================================================
# KB AI Challenge - Professional RAG System (Multilingual)
# =========================================================
import os
import sys
import numpy as np
import traceback
import fitz  # PyMuPDF
from typing import List
# --- Library imports ---
import gradio as gr
import speech_recognition as sr
from dotenv import load_dotenv
# Load .env before reading GROQ_API_KEY below.
load_dotenv()
from deep_translator import GoogleTranslator
from sentence_transformers import SentenceTransformer
from groq import Groq
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
# langchain moved the splitter into its own package; support both layouts.
try:
    from langchain.text_splitter import RecursiveCharacterTextSplitter
except ImportError:
    from langchain_text_splitters import RecursiveCharacterTextSplitter
# =========================================================
# 1. Configuration & initialization
# =========================================================
# Key comes from the environment (.env loaded above); the placeholder default
# is detected below so the app can still start without Groq configured.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "your_groq_api_key_here")
EMBEDDING_MODEL_NAME = "jhgan/ko-sroberta-multitask"  # Korean sentence-embedding model
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
COLLECTION_NAME = "local_kb"
print("π οΈ μμ€ν μ΄κΈ°ν μ€... (System Init)")
# Load the embedding model; truncate inputs at 512 tokens.
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
embedding_model.max_seq_length = 512
# Qdrant vector store (in-memory: contents are lost on restart).
qdrant_client = QdrantClient(":memory:")
try:
    # size=768 matches the ko-sroberta embedding dimension; cosine similarity.
    qdrant_client.recreate_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=768, distance=Distance.COSINE),
    )
    print(f"β Qdrant Collection Ready.")
except Exception as e:
    print(f"β Qdrant Error: {e}")
# Groq client init (left as None when no real API key is configured;
# generate_answer_groq checks for that).
groq_client = None
if GROQ_API_KEY and GROQ_API_KEY != "your_groq_api_key_here":
    try:
        groq_client = Groq(api_key=GROQ_API_KEY)
    except Exception as e:
        print(f"β Groq Error: {e}")
else:
    print("β οΈ Groq API Key Missing.")
# Monotonically increasing point id used for Qdrant upserts.
doc_id_counter = 0
print("β System Ready.")
# =========================================================
# 2. Multilingual support (translation & STT)
# =========================================================
# UI language label -> {"code": Google Translate target code,
#                       "stt": Google Speech Recognition locale}.
LANG_MAP = {
    "νκ΅μ΄ (Korean)": {"code": "ko", "stt": "ko-KR"},
    "English (μμ΄)": {"code": "en", "stt": "en-US"},
    "ζ₯ζ¬θͺ (Japanese)": {"code": "ja", "stt": "ja-JP"},
    "δΈζ (Chinese)": {"code": "zh-CN", "stt": "zh-CN"}
}
def translate_text(text, target_lang_code):
    """Translate *text* into *target_lang_code* via Google Translate.

    Korean ("ko") input is returned unchanged because the pipeline already
    produces Korean. On any translation failure the original text is
    returned as a best-effort fallback.
    """
    if target_lang_code == "ko":
        return text
    try:
        return GoogleTranslator(source='auto', target=target_lang_code).translate(text)
    except Exception:
        # Fix: narrowed a bare `except:` (it also swallowed SystemExit /
        # KeyboardInterrupt). Still best-effort: fall back to the input text.
        return text
def translate_to_korean(text):
    """Translate *text* (any auto-detected language) into Korean.

    Returns the original text unchanged if translation fails, so a
    translator outage degrades gracefully instead of crashing the chat.
    """
    try:
        return GoogleTranslator(source='auto', target='ko').translate(text)
    except Exception:
        # Fix: narrowed a bare `except:` to Exception.
        return text
# =========================================================
# 3. Core logic (RAG pipeline)
# =========================================================
def process_uploaded_files(files):
    """Extract text from uploaded PDFs, chunk it, embed it, and upsert into Qdrant.

    Args:
        files: Gradio file objects (or plain path strings) for the uploaded PDFs.

    Returns:
        A human-readable (Korean) status summary string, one line per file.
    """
    global doc_id_counter
    if not files:
        return "νμΌμ΄ μ νλμ§ μμμ΅λλ€."
    total_chunks = 0
    status_msg = ""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50, length_function=len)
    for file in files:
        try:
            # Gradio file objects expose .name; plain strings are used as-is.
            file_path = file.name if hasattr(file, 'name') else file
            doc = fitz.open(file_path)
            try:
                # str.join avoids quadratic += concatenation on large PDFs.
                file_text = "".join(page.get_text() for page in doc)
            finally:
                # Fix: the document handle was previously never closed (leak).
                doc.close()
            if not file_text.strip():
                status_msg += f"β οΈ {os.path.basename(file_path)}: ν μ€νΈ μμ.\n"
                continue
            chunks = text_splitter.split_text(file_text)
            # Perf fix: batch-encode all chunks in one call instead of one
            # model invocation per chunk.
            vectors = embedding_model.encode(chunks)
            points = []
            for chunk, vector in zip(chunks, vectors):
                payload = {"filename": os.path.basename(file_path), "text": chunk}
                points.append(PointStruct(id=doc_id_counter, vector=vector.tolist(), payload=payload))
                doc_id_counter += 1
            if points:
                qdrant_client.upsert(collection_name=COLLECTION_NAME, points=points)
                total_chunks += len(points)
                status_msg += f"β {os.path.basename(file_path)} ({len(points)} κ° μ μ₯λ¨)\n"
        except Exception as e:
            # Per-file failure is reported but does not abort the other files.
            status_msg += f"β μ€λ₯: {os.path.basename(file_path)} - {str(e)}\n"
    return f"μ΄ {total_chunks}κ° λ°μ΄ν° μ²λ¦¬ μλ£.\n\n{status_msg}"
def search_knowledge_base(query, top_k=5):
    """Return the *top_k* most similar KB chunks for *query*.

    Args:
        query: Korean search text; embedded with the same model used at ingest.
        top_k: maximum number of hits to return.

    Returns:
        A list of scored points (with payloads), or [] on any failure so the
        chat degrades to a "no information found" answer instead of crashing.
    """
    try:
        query_vector = embedding_model.encode(query).tolist()
        res = qdrant_client.query_points(
            collection_name=COLLECTION_NAME, query=query_vector, limit=top_k, with_payload=True
        )
        return res.points
    except Exception:
        # Fix: narrowed a bare `except:` (it also swallowed SystemExit /
        # KeyboardInterrupt).
        return []
def generate_answer_groq(query, context_text):
    """Ask the Groq LLM to answer *query*, grounded only in *context_text*."""
    # Without a configured client there is nothing to call.
    if not groq_client:
        return "API ν€κ° νμν©λλ€."
    system_prompt = """
    λΉμ μ KB κΈμ΅κ·Έλ£Ήμ μ λ¬Έ AI μ΄μμ€ν΄νΈμ λλ€.
    μ 곡λ [λ¬Έλ§₯]μ κΈ°λ°νμ¬ μ§λ¬Έμ λν΄ μ ννκ³ μ λ¬Έμ μΈ λ΅λ³μ μμ±νμΈμ.
    λͺ¨λ₯΄λ λ΄μ©μ λͺ¨λ₯Έλ€κ³ λ΅νκ³ , μΆμΈ‘νμ§ λ§μΈμ.
    λ΅λ³μ νκ΅μ΄λ‘ μμ±νμΈμ.
    """
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"μ§λ¬Έ: {query}\n\n[λ¬Έλ§₯]\n{context_text}"},
    ]
    try:
        completion = groq_client.chat.completions.create(
            messages=chat_messages,
            model=GROQ_MODEL_NAME,
            temperature=0.1,  # low temperature for factual, grounded answers
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"μλ΅ μμ± μ€λ₯: {e}"
def run_rag_chat(message, history, lang_selection):
    """One chat turn: translate to Korean, retrieve, generate, translate back.

    Returns (cleared_input, new_history, reference_text) for the Gradio outputs.
    """
    if not message:
        return "", history, ""
    target_lang = LANG_MAP[lang_selection]["code"]
    # 1. Normalize the query to Korean — KB and LLM prompt are Korean.
    korean_query = translate_to_korean(message) if target_lang != "ko" else message
    # 2. Retrieve and answer (in Korean).
    hits = search_knowledge_base(korean_query)
    if hits:
        context_text = "\n\n".join(h.payload['text'] for h in hits)
        # Group hit scores by source file for the reference panel.
        ref_data = {}
        for h in hits:
            ref_data.setdefault(h.payload['filename'], []).append(h.score)
        reference_text = "\n".join(
            f"- {fname} (κ΄λ ¨ λ΄μ© {len(scores)}건, μ΅κ³ μ μ¬λ: {max(scores):.2f})"
            for fname, scores in ref_data.items()
        )
        bot_response_ko = generate_answer_groq(korean_query, context_text)
    else:
        bot_response_ko = "μ£μ‘ν©λλ€. κ΄λ ¨ μ 보λ₯Ό μ°Ύμ μ μμ΅λλ€."
        reference_text = "μ°Έκ³ λ¬Έμ μμ"
    # 3. Translate the answer back, keeping the Korean original attached.
    if target_lang != "ko":
        translated_response = translate_text(bot_response_ko, target_lang)
        final_response = f"{translated_response}\n\n---\n[νκ΅μ΄ μλ¬Έ]\n{bot_response_ko}"
    else:
        final_response = bot_response_ko
    # Append the turn using Gradio's "messages" history format.
    new_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": final_response},
    ]
    return "", new_history, reference_text
def voice_to_text_chat(audio, history, lang_selection):
    """Transcribe microphone audio with Google STT, then run the RAG chat on it.

    Args:
        audio: (sample_rate, numpy_array) tuple from gr.Audio(type="numpy"), or None.
        history: current chat history (messages format).
        lang_selection: LANG_MAP key selecting the STT locale.

    Returns:
        The same (msg, history, reference_text) triple run_rag_chat returns,
        or an error status in the third slot.
    """
    if audio is None:
        return "", history, "μμ± μ λ ₯ μμ"
    stt_lang = LANG_MAP[lang_selection]["stt"]
    try:
        sample_rate, audio_numpy = audio
        # Fix/generalization: Gradio may deliver float samples in [-1, 1] as
        # float32 OR float64; the old check only handled float32. Convert any
        # floating dtype to 16-bit PCM.
        if np.issubdtype(audio_numpy.dtype, np.floating):
            audio_numpy = (audio_numpy * 32767).astype(np.int16)
        if audio_numpy.ndim > 1:
            # Downmix multi-channel audio to mono.
            audio_numpy = audio_numpy.mean(axis=1).astype(np.int16)
        # sample_width=2 matches the int16 PCM produced above.
        audio_data = sr.AudioData(audio_numpy.tobytes(), sample_rate, 2)
        recognizer = sr.Recognizer()
        # Recognize in the selected language.
        text = recognizer.recognize_google(audio_data, language=stt_lang)
        # Feed the transcript through the normal chat pipeline.
        return run_rag_chat(text, history, lang_selection)
    except sr.UnknownValueError:
        return "", history, "μμ±μ μ΄ν΄ν μ μμ΅λλ€."
    except Exception as e:
        return "", history, f"μ€λ₯: {e}"
# =========================================================
# 4. UI layout (clean, professional, Korean-first)
# =========================================================
theme = gr.themes.Soft(
    primary_hue="amber",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Noto Sans KR"), "sans-serif"]
)
# Hide the Gradio footer and drop the forced container min-height.
css = """
footer {visibility: hidden !important;}
.gradio-container {min-height: 0px !important;}
"""
with gr.Blocks(theme=theme, title="KB AI Challenge", css=css) as demo:
    with gr.Row():
        # --- LEFT SIDEBAR: language selector, KB upload, voice input ---
        with gr.Column(scale=1, min_width=300, variant="panel"):
            gr.Markdown("## KB AI Challenge")
            gr.Markdown("**λ€κ΅μ΄ κΈμ΅ AI μ΄μμ€ν΄νΈ**")
            with gr.Group():
                lang_dropdown = gr.Dropdown(
                    choices=list(LANG_MAP.keys()),
                    value="νκ΅μ΄ (Korean)",
                    label="μΈμ΄ μ€μ ",
                    interactive=True
                )
            file_input = gr.File(label="μ§μ λ² μ΄μ€ (PDF)", file_count="multiple", file_types=[".pdf"])
            with gr.Row():
                upload_btn = gr.Button("μ λ‘λ λ° λΆμ", variant="primary", size="sm")
                upload_status = gr.Textbox(show_label=False, placeholder="μν λκΈ° μ€...", interactive=False, lines=1, max_lines=1)
            gr.Markdown("### μμ± λν")
            audio_input = gr.Audio(sources=["microphone"], type="numpy", label="μμ± μ λ ₯", show_label=False)
            with gr.Accordion("μμ€ν μν€ν μ²", open=False):
                gr.Markdown(
                    """
                    **μ΅μ ν λ΄μ**
                    1. **STT**: Google Speech API
                    2. **λ²μ**: Google Translate API
                    3. **LLM**: Groq LPU (Llama 3)
                    """
                )
        # --- RIGHT MAIN: chat window, references, text input ---
        with gr.Column(scale=3):
            # Chatbot consumes the "messages" history format produced by run_rag_chat.
            chatbot = gr.Chatbot(label="λν", height=500, show_label=False)
            # Reference panel fed by run_rag_chat's third return value.
            gr.Markdown("**μ°Έκ³ λ¬Έμ**")
            ref_output = gr.Textbox(show_label=False, interactive=False, lines=3, max_lines=5, placeholder="κ΄λ ¨ λ¬Έμκ° νμλ©λλ€.")
            # Text input area.
            with gr.Row():
                msg = gr.Textbox(
                    scale=6,
                    show_label=False,
                    placeholder="μ§λ¬Έμ μ λ ₯νμΈμ...",
                    container=False
                )
                submit_btn = gr.Button("μ μ‘", scale=1, variant="primary")
    # --- Event wiring: Enter key, send button, and end-of-recording all
    # route through the same RAG pipeline. ---
    upload_btn.click(process_uploaded_files, inputs=[file_input], outputs=[upload_status])
    msg.submit(run_rag_chat, [msg, chatbot, lang_dropdown], [msg, chatbot, ref_output])
    submit_btn.click(run_rag_chat, [msg, chatbot, lang_dropdown], [msg, chatbot, ref_output])
    audio_input.stop_recording(voice_to_text_chat, [audio_input, chatbot, lang_dropdown], [msg, chatbot, ref_output])
if __name__ == "__main__":
    # share=True exposes a temporary public Gradio link.
    demo.launch(share=True)