| import gradio as gr |
| import os |
| os.environ["OMP_NUM_THREADS"] = "1" |
| import json |
| import torch |
| from parlai.core.opt import Opt |
| from parlai.zoo.blender.blender_3B import download |
| from parlai.core.agents import Agent |
| from parlai.core.params import ParlaiParser |
| from parlai.core.worlds import DialogPartnerWorld |
| from controllable_blender import ControllableBlender |
| from huggingface_hub import snapshot_download |
| from huggingface_hub import login |
|
|
# Make half precision the default floating-point dtype for every tensor
# created without an explicit dtype (presumably to keep the 3B model's
# memory footprint down).
torch.set_default_dtype(torch.float16)

# Hugging Face access token, read from the "Token1" environment variable;
# None when the variable is unset.
token = os.environ.get("Token1")

# Only authenticate when a token is configured: huggingface_hub.login() with
# token=None falls back to an interactive prompt, which nobody can answer on
# a headless server.
if token:
    login(token=token)

# Fetch the ControllableBlender files into ParlAI/data/models/blender/blender_3B
# (presumably so the model files referenced by blender_3B.opt resolve there).
snapshot_download(repo_id="shivansarora/ControllableBlender", local_dir="ParlAI/data/models/blender/blender_3B")

# Base ParlAI options shared by every conversation; init_world() copies and
# overrides them. The context manager closes the file instead of leaking it.
with open("blender_3B.opt", "r", encoding="utf-8") as opt_file:
    agent_opt = json.load(opt_file)

# ParlAI model-zoo download hook for blender_3B, pointed at the opt's datapath.
download(agent_opt["datapath"])

# The single active conversation: the DialogPartnerWorld and the
# GradioHumanAgent that feeds it. Populated lazily by chat() and cleared by
# reset_chat(); because it is module-level, every connected user shares it.
conversation_state = {"world": None, "human_agent": None}
|
|
class GradioHumanAgent(Agent):
    """ParlAI agent standing in for the human user of the Gradio UI.

    Callers assign the user's next message to ``msg`` before each
    ``world.parley()``; ``act`` then hands that text to the world.
    """

    def __init__(self, opt):
        super().__init__(opt)
        # Text of the next user turn; assigned externally before each parley.
        self.msg = None

    def observe(self, msg):
        # The model's reply is not stored here (the UI reads it from
        # world.acts instead); it is simply passed back unchanged.
        observation = msg
        return observation

    def act(self):
        # episode_done is always False, so the dialogue is never marked as
        # finished from the human side.
        reply = dict(text=self.msg, episode_done=False)
        return reply
|
|
|
|
def init_world(cefr, inference_type):
    """Build a fresh human/model dialogue world for one conversation.

    Args:
        cefr: Target CEFR level (one of the UI choices "A1".."C2"); stored in
            the agent options as ``rerank_cefr``.
        inference_type: Controllable inference strategy (UI choices "rerank"
            or "vocab"); stored in the agent options as ``inference``.

    Returns:
        ``(human_agent, world)``: the GradioHumanAgent whose ``msg`` attribute
        supplies user turns, and the DialogPartnerWorld pairing it with a
        ControllableBlender model agent.
    """
    import copy

    # Deep copy so nothing done with this per-conversation opt can reach
    # nested values of the shared base options loaded at start-up.
    opt = copy.deepcopy(agent_opt)
    opt["rerank_cefr"] = cefr
    opt["inference"] = inference_type

    # Settings consumed by ControllableBlender's complexity control. Their
    # exact semantics live in ControllableBlender (not visible here); the
    # names suggest a reranking tokenizer/model, its device, a penalty width,
    # and word filter / word list files.
    opt["rerank_tokenizer"] = "distilroberta-base"
    opt["rerank_model"] = "complexity_model"
    opt["rerank_model_device"] = "cuda"
    opt["penalty_stddev"] = 2
    opt["filter_path"] = "data/filter.txt"
    opt["wordlist_path"] = "data/sample_wordlist.txt"

    # Decoding settings for normal replies (chat() overrides them for its
    # warm-up turn).
    opt["beam_size"] = 20
    opt["topk"] = 40

    human_agent = GradioHumanAgent(opt)
    model_agent = ControllableBlender(opt)
    world = DialogPartnerWorld(opt, [human_agent, model_agent])

    return human_agent, world
|
|
def _warm_up(human_agent, world):
    """Run one throwaway "Hello" turn through a freshly built world.

    The turn uses cheap decoding overrides (beam_size=1, topk=10) on the model
    agent's opt. Afterwards the original values are restored and the model
    agent is reset so the warm-up exchange is not part of the conversation.
    """
    model_agent = world.agents[1]
    overrides = {"beam_size": 1, "topk": 10}
    saved = {key: model_agent.opt[key] for key in overrides if key in model_agent.opt}

    print("🔥 Warming up...")
    for key, value in overrides.items():
        model_agent.opt[key] = value
    human_agent.msg = "Hello"
    try:
        world.parley()
    finally:
        # Put the real decoding settings back; otherwise every later turn
        # would keep the warm-up values wherever the model reads them from opt.
        for key in overrides:
            if key in saved:
                model_agent.opt[key] = saved[key]
            else:
                model_agent.opt.pop(key, None)
        # Forget the warm-up exchange. Assumes the model agent's reset()
        # clears its dialogue history, as ParlAI's TorchAgent.reset does.
        model_agent.reset()
    print("✅ Warmup complete.")


def chat(user_input, cefr, inference_type, history):
    """Send one user message to the model and record the exchange.

    The dialogue world is built lazily on the first message, and rebuilt
    whenever ``cefr`` or ``inference_type`` differ from the values it was
    built with. Those values are only applied when init_world() constructs the
    agents, so a rebuild is the only way a dropdown change takes effect. A
    rebuild starts a fresh model context; the UI transcript is kept.

    Args:
        user_input: The user's message; blank input is ignored.
        cefr: Target CEFR level forwarded to init_world().
        inference_type: Inference strategy forwarded to init_world().
        history: List of ``[user_message, bot_reply]`` pairs, appended to in
            place.

    Returns:
        ``(history, history)`` — the same (possibly updated) list twice.
    """
    # Ignore blank submissions instead of sending an empty turn to the model.
    if not (user_input or "").strip():
        return history, history

    settings = (cefr, inference_type)
    if conversation_state["world"] is not None and conversation_state.get("settings") != settings:
        # Drop the old world first so its model can be released before the
        # replacement is loaded.
        conversation_state["world"] = None
        conversation_state["human_agent"] = None

    if conversation_state["world"] is None:
        human_agent, world = init_world(cefr, inference_type)
        conversation_state["world"] = world
        conversation_state["human_agent"] = human_agent
        conversation_state["settings"] = settings
        _warm_up(human_agent, world)

    world = conversation_state["world"]
    conversation_state["human_agent"].msg = user_input
    world.parley()

    # acts[1] is the model agent's act from this parley; guard against a
    # missing or None "text" field before stripping.
    bot_reply = world.acts[1].get("text") or ""
    history.append([user_input, bot_reply.strip()])
    return history, history
|
|
def reset_chat():
    """Forget the active dialogue world so the next message builds a new one.

    Returns:
        A new empty list, i.e. a cleared chat history.
    """
    conversation_state.update(world=None, human_agent=None)
    return []
|
|
with gr.Blocks() as demo:
    # Generation controls: target CEFR level and controllable inference strategy.
    cefr = gr.Dropdown(["A1", "A2", "B1", "B2", "C1", "C2"], label="CEFR", value="B2")
    inference_type = gr.Dropdown(["rerank", "vocab"], label="Inference", value="rerank")
    user_input = gr.Textbox(label="your message")
    chatbot = gr.Chatbot(label="Controllable Complexity Chatbot")
    send_btn = gr.Button("Send")

    # Chat transcript as [user_message, bot_reply] pairs, held in gr.State.
    state = gr.State([])

    def user_chat(message, cefr_level, infer_type, history):
        """Handle a Send click: forward the message to chat() and log the exchange.

        Returns the updated history twice, once for the Chatbot display and
        once for the session state.
        """
        # Log the submitted text itself, not the Textbox component object.
        print("Received:", message)
        # chat() appends to `history` in place, so record its length first.
        turns_before = len(history)
        new_history, _ = chat(message, cefr_level, infer_type, history)
        if len(new_history) > turns_before:
            print("Replied:", new_history[-1][1])
        return new_history, new_history

    send_btn.click(
        fn=user_chat,
        inputs=[user_input, cefr, inference_type, state],
        outputs=[chatbot, state],
    )


demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)
|
|