Spaces:
Sleeping
Sleeping
| import os | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| from google import genai | |
| import gradio as gr | |
# === Environment settings ===
# Directory of the persistent Chroma vector DB; overridable via CHROMA_DB_DIR.
DB_DIR = os.getenv("CHROMA_DB_DIR", os.path.join(os.getcwd(), "chromadb_KH_media"))
# Write the resolved path back so any other code reading CHROMA_DB_DIR sees the same value.
os.environ["CHROMA_DB_DIR"] = DB_DIR
# The Google API key must come from the environment (e.g. a Space secret).
# Never commit a literal key as a fallback: it leaks the secret and silently bills
# the owner's quota. Any key that was previously committed here should be revoked.
# If unset this is None, and genai.Client() will look up the key itself or fail loudly.
API_KEY = os.getenv("GOOGLE_API_KEY")
# === Simple RAG system ===
class SimpleRAGSystem:
    """Minimal retriever over a persistent Chroma collection.

    Queries are embedded with a Korean Sentence-BERT model and matched against
    the stored document embeddings; ``search`` returns the raw document strings.
    """

    def __init__(self, db_path=None, collection_name="KH_media_docs"):
        """Load the encoder and open the Chroma collection.

        Args:
            db_path: Directory of the persistent Chroma DB; falls back to ``DB_DIR``.
            collection_name: Name of the collection to query.

        If the collection cannot be opened (for example it has not been built
        yet), the system is marked unavailable instead of crashing the app.
        """
        path = db_path or DB_DIR
        self.encoder = SentenceTransformer("snunlp/KR-SBERT-V40K-klueNLI-augSTS")
        self.client = chromadb.PersistentClient(path=path)
        try:
            self.collection = self.client.get_collection(name=collection_name)
        except Exception as exc:  # chromadb raises different exception types across versions
            import logging

            logging.getLogger(__name__).warning(
                "Chroma collection %r unavailable at %r: %s", collection_name, path, exc
            )
            self.collection = None
        # Retrieval is only enabled when the collection exists and holds documents.
        self.available = self.collection is not None and self.collection.count() > 0

    def search(self, query, top_k=10):
        """Return up to *top_k* stored documents most similar to *query*.

        Returns an empty list when retrieval is unavailable or ``top_k < 1``.
        """
        if not self.available or top_k < 1:
            return []
        emb = self.encoder.encode(query).tolist()
        result = self.collection.query(
            query_embeddings=[emb],
            n_results=top_k,
            include=["documents"],
        )
        # One inner list per query embedding; we sent exactly one.
        documents = result.get("documents") or [[]]
        return documents[0] or []
# === Shared service objects (created once at import time) ===
# Retriever over the persistent Chroma collection used to ground answers.
rag: SimpleRAGSystem = SimpleRAGSystem(db_path=DB_DIR)
# === Google GenAI client ===
# Text-generation client; API_KEY is defined in the environment settings section.
client: genai.Client = genai.Client(api_key=API_KEY)
# === System message ===
# Base system prompt: "You are a professional counselling AI for Kyung Hee
# University's Department of Media", followed by sections on the main role,
# answer style, how to use the reference documents, and the current faculty list.
# Intended as the value of respond()'s `system_message` argument — verify it is
# actually supplied through the ChatInterface inputs.
# NOTE(review): the Korean text below appears mis-encoded (UTF-8 bytes rendered
# through a Thai code page, with some bytes lost). Restore it from the original
# UTF-8 source; the faculty names in particular cannot be reliably recovered here.
SYSTEM_MSG = """
๋น์ ์ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ ์ ๋ฌธ ์๋ด AI์ ๋๋ค.
# ์ฃผ์ ์ญํ :
- ์ ๊ณต๋ ๋ฌธ์ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ๋ต๋ณ ์ ๊ณต
- ๋ฏธ๋์ดํ๊ณผ ๊ด๋ จ ์ง๋ฌธ์ ์น์ ํ๊ณ ๊ตฌ์ฒด์ ์ผ๋ก ์๋ต
- ๋ฌธ์์ ์๋ ๋ด์ฉ์ ์ผ๋ฐ ์ง์์ผ๋ก ๋ณด์ (๋จ, ๋ช ์)
# ๋ต๋ณ ์คํ์ผ:
- ์์ธํ๊ณ ํ๋ถํ ์ค๋ช ์ ํฌํจํ์ฌ ์์ธํ๊ณ ๊ธธ๊ฒ ๋ต๋ณ ์ ๊ณต
- ์น๊ทผํ๊ณ ๋์์ด ๋๋ ์๋ด์ฌ ํค
- ํต์ฌ ์ ๋ณด๋ฅผ ๋ช ํํ๊ฒ ์ ๋ฌ
- ์ถ๊ฐ ๊ถ๊ธํ ์ ์ด ์์ผ๋ฉด ์ธ์ ๋ ๋ฌผ์ด๋ณด๋ผ๊ณ ์๋ด
# ์ฐธ๊ณ ๋ฌธ์ ํ์ฉ:
- ๋ฌธ์ ๋ด์ฉ์ด ์์ผ๋ฉด ๊ตฌ์ฒด์ ์ผ๋ก ์ธ์ฉ
- ์ฌ๋ฌ ๋ฌธ์์ ์ ๋ณด๋ฅผ ์ข ํฉํ์ฌ ๋ต๋ณ ์์ฑ
- ์ ํํ์ง ์์ ์ ๋ณด๋ ์ถ์ธกํ์ง ๋ง๊ณ ์์งํ๊ฒ ๋ชจ๋ฅธ๋ค๊ณ ๋ต๋ณ
# ํ์ฌ ๊ฒฝํฌ๋ํ๊ต ๋ฏธ๋์ดํ๊ณผ ๊ต์์ง:
์ด์ธํฌ, ๊นํ์ฉ, ๋ฐ์ข ๋ฏผ, ํ์ง์, ์ด์ ๊ต, ์ด๊ธฐํ, ์ด์ ์, ์กฐ์์, ์ด์ข ํ, ์ด๋ํฉ, ์ด์์, ์ดํ, ์ต์์ง, ์ต๋ฏผ์, ๊น๊ดํธ
"""
# === Response function ===
def _history_to_transcript(history):
    """Render earlier chat turns as "사용자: ...\\nAI: ...\\n" transcript lines."""
    speakers = {"user": "사용자", "assistant": "AI"}
    lines = []
    for turn in history or []:
        if isinstance(turn, dict):
            # type="messages": one dict per message. Skip non-text content
            # (files/components) and roles other than user/assistant.
            speaker = speakers.get(turn.get("role"))
            content = turn.get("content")
            if speaker and isinstance(content, str):
                lines.append(f"{speaker}: {content}\n")
        else:
            # Legacy "tuples" format: [user_message, assistant_message].
            user_msg, bot_msg = turn
            lines.append(f"사용자: {user_msg}\n")
            if bot_msg is not None:
                lines.append(f"AI: {bot_msg}\n")
    return "".join(lines)


def respond(message, history, system_message, max_tokens, temperature, top_p, model_name):
    """Generate a reply to *message*, grounded on documents retrieved by ``rag``.

    Called by ``gr.ChatInterface`` as ``fn(message, history, *additional_inputs)``.

    Args:
        message: Latest user message.
        history: Previous turns. With ``type="messages"`` Gradio passes a list of
            ``{"role": ..., "content": ...}`` dicts; legacy ``(user, assistant)``
            pairs are accepted as well.
        system_message: Base system prompt; retrieved documents are appended to it.
        max_tokens: Maximum number of output tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        model_name: Google GenAI model id (e.g. ``"gemini-2.0-flash"``).

    Returns:
        The model's text, or a Korean error message if the API call fails.
        Errors are returned rather than raised so they appear in the chat.
    """
    docs = rag.search(message)  # search() returns [] when retrieval is unavailable
    ctx = "\n".join(f"참고문서{i + 1}: {doc}" for i, doc in enumerate(docs))
    sys_msg = (system_message or "") + ("\n# 참고문서:\n" + ctx if ctx else "")
    # The whole conversation is sent as one prompt string (no separate system
    # instruction), so the same request shape is used for Gemini and Gemma models.
    prompt = f"{sys_msg}\n{_history_to_transcript(history)}사용자: {message}\nAI:"
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=prompt,
            config={
                "max_output_tokens": int(max_tokens),
                "temperature": temperature,
                "top_p": top_p,
            },
        )
        return response.text or "응답이 없습니다."
    except Exception as e:  # UI boundary: surface every API failure in the chat
        err = str(e).lower()
        if "quota" in err or "resource_exhausted" in err:
            return "API 할당량을 초과했습니다. 나중에 시도해주세요."
        # Invalid keys are reported as "API key not valid" / API_KEY_INVALID,
        # which never contain the word "authentication".
        if any(marker in err for marker in ("authentication", "api key", "api_key")):
            return "인증 오류: API 키를 확인하세요."
        return f"오류 발생: {e}"
# === Gradio interface ===
demo = gr.ChatInterface(
    fn=respond,
    # NOTE(review): title emoji reconstructed from mis-encoded source (🎬 or 💬) — confirm.
    title="🎬 경희대학교 미디어학과 AI 상담사",
    description="경희대학교 미디어학과에 대해 물어보세요!",
    # Order must match respond(message, history, system_message, max_tokens,
    # temperature, top_p, model_name): the system prompt comes first.
    additional_inputs=[
        gr.Textbox(value=SYSTEM_MSG, label="시스템 메시지", lines=6),
        gr.Slider(128, 2048, value=1024, step=64, label="최대 토큰"),
        gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        gr.Dropdown(
            choices=[
                "gemini-2.0-flash", "gemini-2.0-flash-lite",
                "gemini-1.5-flash", "gemini-1.5-pro",
                "gemma-3-27b-it", "gemma-3-12b-it", "gemma-3-4b-it",
            ],
            value="gemini-2.0-flash",
            label="모델 선택",
        ),
    ],
    additional_inputs_accordion="🔧 고급 설정",
    examples=[
        ["미디어학과에서 배우는 주요 과목들은 무엇인가요?"],
        ["미디어학과 교수진을 소개해주세요."],
        ["미디어학과 졸업 후 진로는 어떻게 되나요?"],
        ["미디어학과 입학 전형에 대해 알려주세요."],
        ["미디어학과 동아리나 학생 활동은 어떤 것들이 있나요?"],
    ],
    type="messages",  # history is passed to respond() as role/content dicts
    theme="soft",
    analytics_enabled=False,
)
# === System info panel ===
def _system_info_markdown():
    """Return the Markdown for the info panel: models offered and RAG index status."""
    rag_status = "✅ 활성화" if rag.available else "❌ 비활성화"
    doc_count = rag.collection.count() if rag.available else 0
    return "\n".join([
        "---",
        "### ⚙️ 시스템 정보",
        # One list item per line: plain consecutive lines would be merged into
        # a single paragraph by the Markdown renderer.
        "- **언어 모델**: Google Gemini 2.0 Flash / Flash-Lite, Gemini 1.5 Flash / Pro, "
        "Gemma 3 (4B/12B/27B) 선택 가능",
        "- **임베딩 모델**: snunlp/KR-SBERT-V40K-klueNLI-augSTS (한국어 특화)",
        f"- **RAG 상태**: {rag_status}",
        f"- **문서 수**: {doc_count}개",
    ])


# Append the info panel below the chat UI.
with demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(_system_info_markdown())
def _main():
    """Start the Gradio server without creating a public share link."""
    demo.launch(share=False)


if __name__ == "__main__":
    _main()