Spaces:
Sleeping
Sleeping
import html
import os
import re

import gradio as gr

from core.context_manager import ContextManager
from core.make_pipeline import MakePipeline
from core.make_reply import generate_reply
from core.utils import load_config as load_full_config, save_config as save_full_config, load_llm_config
def create_interface(ctx: ContextManager, makePipeline: MakePipeline):
    """Build the Gradio UI for chatting with the character and editing its settings.

    Three tabs are created:

    1. Chat: renders the conversation history of the per-session ``gr.State``
       (initialised from ``ctx``) as left/right HTML bubbles and calls
       ``generate_reply`` for every submitted message.
    2. Model settings: sliders for temperature / top-p / repetition penalty /
       max new tokens. "Apply" pushes them to ``makePipeline.update_config``;
       load/export read and write the ``"llm"`` section of ``config.json``.
    3. Prompt settings: user/bot names (``"cha"`` section of ``config.json``,
       also pushed into the session context) and a raw text editor for
       ``assets/prompt/init.txt``.

    Args:
        ctx: Conversation context used as the initial value of the chat ``gr.State``.
        makePipeline: Generation pipeline handed to ``generate_reply`` and
            reconfigured from the settings tab.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    # NOTE(review): the Korean UI strings below appear mis-encoded in this copy
    # (UTF-8 bytes decoded as a Thai code page). They are kept verbatim here;
    # restore the originals from the upstream source.
    config_path = "config.json"
    prompt_path = "assets/prompt/init.txt"

    css = """
    .chat-box { max-height: 500px; overflow-y: auto; padding: 10px; border: 1px solid #ccc; border-radius: 10px; }
    .bubble-left { background-color: #f1f0f0; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: left; clear: both; }
    .bubble-right { background-color: #d1e7ff; border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; float: right; clear: both; text-align: right; }
    .reset-btn-container { text-align: right; margin-bottom: 10px; }
    """

    # Each match is either a *...* action span (group 1) or a run of
    # non-asterisk text (group 2). A lone '*' with no closing partner matches
    # neither alternative and is therefore dropped from the output.
    emotion_pattern = re.compile(r"\*(.+?)\*|([^\*]+)")

    # History role -> CSS bubble class; other roles are not displayed.
    bubble_classes = {"user": "bubble-right", "bot": "bubble-left"}

    def parse_emotion_text(text: str) -> str:
        """Convert one chat message into HTML.

        ``*...*`` spans become gray action lines. The remaining text gets one
        ``<div>`` per non-empty line. All message text is HTML-escaped, so
        markup typed by the user or produced by the model is shown literally
        instead of being injected into the page.
        """
        segments = []
        for action, plain in emotion_pattern.findall(text):
            if action:
                segments.append(f"<div style='color:gray'>*{html.escape(action)}*</div>")
            elif plain:
                for line in plain.strip().splitlines():
                    line = line.strip()
                    if line:
                        segments.append(f"<div>{html.escape(line)}</div>")
        return "\n".join(segments)

    def render_chat(ctx: ContextManager):
        """Render the full history of ``ctx`` as chat-bubble HTML, wrapped in ``gr.update``."""
        bubbles = []
        for item in ctx.getHistory():
            css_class = bubble_classes.get(item["role"])
            if css_class is None:
                continue
            bubbles.append(f"<div class='{css_class}'>{parse_emotion_text(item['text'])}</div>")
        return gr.update(value="".join(bubbles))

    with gr.Blocks(css=css) as demo:
        with gr.Tabs():
            # --- 1. Chat tab ---
            with gr.TabItem("๐ฌ ํ์ง๋ก์ ๋ํํ๊ธฐ"):
                with gr.Column():
                    with gr.Row():
                        gr.Markdown("### ํ์ง๋ก์ ๋ํํ๊ธฐ")
                        # NOTE(review): recent Gradio versions expect an integer `scale`;
                        # 0.25 may emit a warning - confirm against the installed version.
                        reset_btn = gr.Button("๐ ๋ํ ์ด๊ธฐํ", elem_classes="reset-btn-container", scale=0.25)
                    # The stylesheet targets the `.chat-box` class; elem_id alone never
                    # matched it, so the scroll/max-height styling was not applied.
                    chat_output = gr.HTML(elem_id="chat-box", elem_classes="chat-box")
                    user_input = gr.Textbox(label="๋ฉ์์ง ์ ๋ ฅ", placeholder="ํ์ง๋ก์๊ฒ ๋ง์ ๊ฑธ์ด๋ณด์ธ์")
                    state = gr.State(ctx)

                def on_submit(user_msg: str, ctx: ContextManager):
                    """Append the user's message, show it, generate the reply, then show everything.

                    Yields ``(chat_html_update, cleared_textbox_value, ctx)`` twice,
                    so the user's message appears before ``generate_reply`` runs.
                    """
                    # Blank submissions: just re-render; don't add an empty turn or call the model.
                    if not user_msg or not user_msg.strip():
                        yield render_chat(ctx), "", ctx
                        return
                    ctx.addHistory("user", user_msg)
                    yield render_chat(ctx), "", ctx
                    # generate_reply is expected to record the bot reply in ctx's history,
                    # which the re-render below then displays.
                    generate_reply(ctx, makePipeline, user_msg)
                    yield render_chat(ctx), "", ctx

                def reset_chat(ctx: ContextManager):
                    """Clear the session history and blank the chat view and input box."""
                    ctx.clearHistory()
                    return gr.update(value=""), "", ctx

                user_input.submit(on_submit, inputs=[user_input, state], outputs=[chat_output, user_input, state], queue=True)
                reset_btn.click(reset_chat, inputs=[state], outputs=[chat_output, user_input, state])

            # --- 2. Model settings tab ---
            with gr.TabItem("โ๏ธ ๋ชจ๋ธ ์ค์ "):
                gr.Markdown("### LLM ํ๋ผ๋ฏธํฐ ์ค์ ")
                with gr.Row():
                    temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p")
                    repetition_penalty = gr.Slider(0.8, 2.0, value=1.05, step=0.01, label="Repetition Penalty")
                with gr.Row():
                    max_tokens = gr.Slider(16, 2048, value=96, step=8, label="Max New Tokens")
                    apply_btn = gr.Button("โ ์ค์ ์ ์ฉ")

                def update_config(temp, topp, max_tok, repeat):
                    """Push the slider values into the shared pipeline (not persisted to disk)."""
                    makePipeline.update_config({
                        "temperature": temp,
                        "top_p": topp,
                        # Token counts must be integers; slider values may arrive as floats.
                        "max_new_tokens": int(max_tok),
                        "repetition_penalty": repeat,
                    })
                    return gr.update(value="โ ์ค์ ์ ์ฉ ์๋ฃ")

                # Load / export config buttons.
                with gr.Row():
                    load_btn = gr.Button("๐ ์ค์ ๋ถ๋ฌ์ค๊ธฐ")
                    save_btn = gr.Button("๐พ ์ค์ ๋ด๋ณด๋ด๊ธฐ")

                def load_config():
                    """Read LLM settings from config.json into the sliders (does not apply them)."""
                    llm_cfg = load_llm_config(config_path)
                    return (
                        llm_cfg.get("temperature", 0.7),
                        llm_cfg.get("top_p", 0.9),
                        llm_cfg.get("repetition_penalty", 1.05),
                        llm_cfg.get("max_new_tokens", 96),
                        "๐ ์ค์ ๋ถ๋ฌ์ค๊ธฐ ์๋ฃ",
                    )

                def save_config(temp, topp, repeat, max_tok):
                    """Write the slider values into the "llm" section of config.json."""
                    config = load_full_config(config_path)
                    # Merge into the existing "llm" section so keys this tab does not
                    # manage are preserved; other top-level sections are untouched.
                    config["llm"] = {
                        **(config.get("llm") or {}),
                        "temperature": temp,
                        "top_p": topp,
                        "repetition_penalty": repeat,
                        "max_new_tokens": int(max_tok),
                    }
                    save_full_config(config, path=config_path)
                    return gr.update(value="๐พ ์ค์ ์ ์ฅ ์๋ฃ")

                # Status box placed at the bottom of the tab.
                status = gr.Textbox(label="", interactive=False)

                apply_btn.click(
                    update_config,
                    inputs=[temperature, top_p, max_tokens, repetition_penalty],
                    outputs=[status],
                )
                load_btn.click(
                    load_config,
                    inputs=None,
                    outputs=[temperature, top_p, repetition_penalty, max_tokens, status],
                )
                save_btn.click(
                    save_config,
                    inputs=[temperature, top_p, repetition_penalty, max_tokens],
                    outputs=[status],
                )

            # --- 3. Prompt editing tab ---
            with gr.TabItem("๐ ํ๋กฌํํธ ์ค์ "):
                gr.Markdown("### ์ฌ์ฉ์ ๋ฐ ์บ๋ฆญํฐ ์ด๋ฆ ์ค์ ")
                with gr.Row():
                    user_name = gr.Textbox(label="๐ค ์ฌ์ฉ์ ์ด๋ฆ")
                    bot_name = gr.Textbox(label="๐ค ์บ๋ฆญํฐ ์ด๋ฆ")
                name_status = gr.Textbox(label="", interactive=False)
                with gr.Row():
                    load_name_btn = gr.Button("๐ ์ด๋ฆ ๋ถ๋ฌ์ค๊ธฐ")
                    save_name_btn = gr.Button("๐พ ์ด๋ฆ ์ ์ฅํ๊ธฐ")

                def load_names(ctx):
                    """Read user/bot names from config.json and apply them to the session context."""
                    cha_cfg = load_full_config(config_path).get("cha") or {}
                    user = cha_cfg.get("user_name", "user")
                    bot = cha_cfg.get("bot_name", "Tanjiro")
                    ctx.setUserName(user)
                    ctx.setBotName(bot)
                    return user, bot, "๐ ์ด๋ฆ ๋ถ๋ฌ์ค๊ธฐ ์๋ฃ"

                def save_names(user, bot, ctx):
                    """Persist user/bot names to config.json and apply them to the session context."""
                    config = load_full_config(config_path)
                    # Merge so other keys stored under "cha" are not discarded.
                    config["cha"] = {
                        **(config.get("cha") or {}),
                        "user_name": user,
                        "bot_name": bot,
                    }
                    save_full_config(config, path=config_path)
                    ctx.setUserName(user)
                    ctx.setBotName(bot)
                    return "๐พ ์ด๋ฆ ์ ์ฅ ์๋ฃ!"

                load_name_btn.click(
                    fn=load_names,
                    inputs=[state],
                    outputs=[user_name, bot_name, name_status],
                )
                save_name_btn.click(
                    save_names,
                    inputs=[user_name, bot_name, state],
                    outputs=[name_status],
                )
                # Load the names once when the page is opened.
                demo.load(
                    fn=load_names,
                    inputs=[state],
                    outputs=[user_name, bot_name, name_status],
                )

                gr.Markdown("### ์บ๋ฆญํฐ ๋ฐ ์ธ๊ณ๊ด ํ๋กฌํํธ ํธ์ง")
                prompt_editor = gr.Textbox(
                    lines=20,
                    label="ํ ์คํธ (init.txt)",
                    placeholder="!! ๋ฐ๋์ ๋ถ๋ฌ์ค๊ธฐ๋ฅผ ๋จผ์ ํ์ธ์ !!",
                    interactive=True,
                )
                with gr.Row():
                    gr.Markdown("#### !! ๋ฐ๋์ ๋ถ๋ฌ์ค๊ธฐ๋ฅผ ๋จผ์ ํ์ธ์ !!")
                with gr.Row():
                    load_prompt_btn = gr.Button("๐ ํ์ฌ ํ๋กฌํํธ ๋ถ๋ฌ์ค๊ธฐ")
                    save_prompt_btn = gr.Button("๐พ ์์ฑํ ํ๋กฌํํธ๋ก ๊ต์ฒด")
                # Dedicated status box: previously the save message was written into
                # the save button itself, permanently replacing its label.
                prompt_status = gr.Textbox(label="", interactive=False)

                def load_prompt():
                    """Return the current prompt file's text, or "" if it does not exist yet."""
                    try:
                        with open(prompt_path, "r", encoding="utf-8") as f:
                            return f.read()
                    except FileNotFoundError:
                        return ""

                def save_prompt(text):
                    """Overwrite the prompt file with the editor contents, creating its folder if needed."""
                    os.makedirs(os.path.dirname(prompt_path), exist_ok=True)
                    with open(prompt_path, "w", encoding="utf-8") as f:
                        f.write(text)
                    return "๐พ ์ ์ฅ ์๋ฃ!"

                load_prompt_btn.click(
                    load_prompt,
                    inputs=None,
                    outputs=prompt_editor,
                )
                save_prompt_btn.click(
                    save_prompt,
                    inputs=[prompt_editor],
                    outputs=[prompt_status],
                )
    return demo