Spaces:
Runtime error
Runtime error
Update OTF_ComplexControl.py
Browse files- OTF_ComplexControl.py +27 -11
OTF_ComplexControl.py
CHANGED
|
@@ -4,11 +4,12 @@ import json
|
|
| 4 |
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
|
| 5 |
import sys
|
| 6 |
sys.path.append(f'../source')
|
| 7 |
-
import
|
| 8 |
-
import helpers
|
| 9 |
import torch
|
| 10 |
from huggingface_hub import snapshot_download
|
| 11 |
from huggingface_hub import login
|
|
|
|
|
|
|
| 12 |
|
| 13 |
parser = argparse.ArgumentParser(description="Generate responses on a CEFR level")
|
| 14 |
parser.add_argument("--n", type=int, default=10, help="Number of dialog contexts. Default: %(default)s")
|
|
@@ -77,7 +78,7 @@ def detect_cefr_level(text: str) -> str:
|
|
| 77 |
|
| 78 |
"""Response Generation Script"""
|
| 79 |
def get_response(prompt):
|
| 80 |
-
response_list =
|
| 81 |
response_str = "".join(response_list) if isinstance(response_list, list) else str(response_list)
|
| 82 |
return parse_response(response_str)
|
| 83 |
|
|
@@ -101,23 +102,22 @@ llm_model = AutoModelForCausalLM.from_pretrained(
|
|
| 101 |
responses = []
|
| 102 |
conversation_history = []
|
| 103 |
MAX_TURNS = 5 # Limit the number of turns to keep context manageable
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
if user_input.
|
| 108 |
-
|
| 109 |
|
| 110 |
# 2) Detect CEFR from input context
|
| 111 |
detected_level = detect_cefr_level(user_input)
|
| 112 |
print(f"[DEBUG] Detected CEFR = {detected_level} for context: {user_input}")
|
| 113 |
|
| 114 |
# 3) Build prompt using detected CEFR
|
| 115 |
-
|
| 116 |
conversation_history.append({"role": "user", "text": user_input, "CEFR": detected_level})
|
| 117 |
recent_turns = conversation_history[-MAX_TURNS*2:] # *2 because each turn has user+model
|
| 118 |
|
| 119 |
item = {"context": recent_turns, "CEFR": detected_level, "response": ""}
|
| 120 |
-
item =
|
| 121 |
print(f"[DEBUG] Prompt for response generation: {item['prompt']}")
|
| 122 |
|
| 123 |
# 4) Generate response
|
|
@@ -125,4 +125,20 @@ while True:
|
|
| 125 |
print(f"[{detected_level}] {response}")
|
| 126 |
|
| 127 |
# 5) Update conversation history
|
| 128 |
-
conversation_history.append({"role": "model", "text": response, "CEFR": detected_level})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
|
| 5 |
import sys
|
| 6 |
sys.path.append(f'../source')
|
| 7 |
+
import cefr_utils
|
|
|
|
| 8 |
import torch
|
| 9 |
from huggingface_hub import snapshot_download
|
| 10 |
from huggingface_hub import login
|
| 11 |
+
import gradio as gr
|
| 12 |
+
import os
|
| 13 |
|
| 14 |
parser = argparse.ArgumentParser(description="Generate responses on a CEFR level")
|
| 15 |
parser.add_argument("--n", type=int, default=10, help="Number of dialog contexts. Default: %(default)s")
|
|
|
|
| 78 |
|
| 79 |
"""Response Generation Script"""
|
| 80 |
def get_response(prompt):
|
| 81 |
+
response_list = cefr_utils.generate(llm_model, llm_tokenizer, [prompt])
|
| 82 |
response_str = "".join(response_list) if isinstance(response_list, list) else str(response_list)
|
| 83 |
return parse_response(response_str)
|
| 84 |
|
|
|
|
| 102 |
responses = []  # NOTE(review): not referenced in the visible chat flow — possibly leftover from the batch script
conversation_history = []  # full dialog as {"role", "text", "CEFR"} dicts; mutated in place by chat()
MAX_TURNS = 5 # Limit the number of turns to keep context manageable
|
| 105 |
+
def chat(user_input):
    """Handle one Gradio chat turn.

    Detects the CEFR level of *user_input*, builds a level-conditioned prompt
    from the recent conversation history, generates a model reply, and returns
    a ``(chatbot_history, textbox_value)`` pair for the ``[chatbot, msg]``
    outputs wired up in the UI.
    """
    global conversation_history

    # 1) Guard against empty input. Render the existing history through the
    #    same (user, bot)-tuple converter as the success path — returning the
    #    raw dict list here would make gr.Chatbot fail to display it.
    if not user_input.strip():
        return _as_gradio_pairs(conversation_history), "Please enter a message."

    # 2) Detect CEFR from input context
    detected_level = detect_cefr_level(user_input)
    print(f"[DEBUG] Detected CEFR = {detected_level} for context: {user_input}")

    # 3) Build prompt using detected CEFR
    conversation_history.append({"role": "user", "text": user_input, "CEFR": detected_level})
    recent_turns = conversation_history[-MAX_TURNS*2:]  # *2 because each turn has user+model

    item = {"context": recent_turns, "CEFR": detected_level, "response": ""}
    item = cefr_utils.get_CEFR_prompt(item, apply_chat_template=llm_tokenizer.apply_chat_template)
    print(f"[DEBUG] Prompt for response generation: {item['prompt']}")

    # 4) Generate response
    # NOTE(review): this call was elided at a diff-hunk boundary in the source
    # view; reconstructed from the surrounding debug prints — confirm against
    # the original file's line 124.
    response = get_response(item["prompt"])
    print(f"[{detected_level}] {response}")

    # 5) Update conversation history
    conversation_history.append({"role": "model", "text": response, "CEFR": detected_level})

    # 6) Re-render the history for the Chatbot and clear the textbox.
    return _as_gradio_pairs(conversation_history), ""


def _as_gradio_pairs(history):
    """Convert the internal dict history into gr.Chatbot (user, bot) tuples."""
    pairs = []
    for turn in history:
        if turn["role"] == "user":
            pairs.append((turn["text"], None))
        else:
            # A model turn fills the bot slot of the most recent user message.
            pairs[-1] = (pairs[-1][0], turn["text"])
    return pairs
|
| 138 |
+
|
| 139 |
+
# Build the Gradio UI: a chat window plus a single-line input box.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Adaptive CEFR chatbot")
    msg = gr.Textbox(placeholder="Type your message here...")
    # FIX: Gradio event listeners (Textbox.submit) accept no `clear_on_submit`
    # kwarg — passing it raises TypeError at startup (the Space's "Runtime
    # error"). The box is cleared instead by `chat` returning "" as the second
    # output, which is routed back into `msg`.
    msg.submit(chat, inputs=msg, outputs=[chatbot, msg])

demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)
|