shivansarora commited on
Commit
6632815
Β·
verified Β·
1 Parent(s): 70e00b3

Update OTF_ComplexControl.py

Browse files
Files changed (1) hide show
  1. OTF_ComplexControl.py +27 -11
OTF_ComplexControl.py CHANGED
@@ -4,11 +4,12 @@ import json
4
  from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
5
  import sys
6
  sys.path.append(f'../source')
7
- import models
8
- import helpers
9
  import torch
10
  from huggingface_hub import snapshot_download
11
  from huggingface_hub import login
 
 
12
 
13
  parser = argparse.ArgumentParser(description="Generate responses on a CEFR level")
14
  parser.add_argument("--n", type=int, default=10, help="Number of dialog contexts. Default: %(default)s")
@@ -77,7 +78,7 @@ def detect_cefr_level(text: str) -> str:
77
 
78
  """Response Generation Script"""
79
  def get_response(prompt):
80
- response_list = models.generate(llm_model, llm_tokenizer, [prompt])
81
  response_str = "".join(response_list) if isinstance(response_list, list) else str(response_list)
82
  return parse_response(response_str)
83
 
@@ -101,23 +102,22 @@ llm_model = AutoModelForCausalLM.from_pretrained(
101
  responses = []
102
  conversation_history = []
103
  MAX_TURNS = 5 # Limit the number of turns to keep context manageable
104
- while True:
105
- # 1) Get user input
106
- user_input = input("User: ").strip()
107
- if user_input.lower() in ["quit", "exit"]:
108
- break
109
 
110
  # 2) Detect CEFR from input context
111
  detected_level = detect_cefr_level(user_input)
112
  print(f"[DEBUG] Detected CEFR = {detected_level} for context: {user_input}")
113
 
114
  # 3) Build prompt using detected CEFR
115
-
116
  conversation_history.append({"role": "user", "text": user_input, "CEFR": detected_level})
117
  recent_turns = conversation_history[-MAX_TURNS*2:] # *2 because each turn has user+model
118
 
119
  item = {"context": recent_turns, "CEFR": detected_level, "response": ""}
120
- item = helpers.get_CEFR_prompt(item, apply_chat_template=llm_tokenizer.apply_chat_template)
121
  print(f"[DEBUG] Prompt for response generation: {item['prompt']}")
122
 
123
  # 4) Generate response
@@ -125,4 +125,20 @@ while True:
125
  print(f"[{detected_level}] {response}")
126
 
127
  # 5) Update conversation history
128
- conversation_history.append({"role": "model", "text": response, "CEFR": detected_level})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
5
  import sys
6
  sys.path.append(f'../source')
7
+ import cefr_utils
 
8
  import torch
9
  from huggingface_hub import snapshot_download
10
  from huggingface_hub import login
11
+ import gradio as gr
12
+ import os
13
 
14
  parser = argparse.ArgumentParser(description="Generate responses on a CEFR level")
15
  parser.add_argument("--n", type=int, default=10, help="Number of dialog contexts. Default: %(default)s")
 
78
 
79
  """Response Generation Script"""
80
  def get_response(prompt):
81
+ response_list = cefr_utils.generate(llm_model, llm_tokenizer, [prompt])
82
  response_str = "".join(response_list) if isinstance(response_list, list) else str(response_list)
83
  return parse_response(response_str)
84
 
 
102
  responses = []
103
  conversation_history = []
104
  MAX_TURNS = 5 # Limit the number of turns to keep context manageable
105
+ def chat(user_input):
106
+ global conversation_history
107
+
108
+ if not user_input.strip():
109
+ return conversation_history, "Please enter a message."
110
 
111
  # 2) Detect CEFR from input context
112
  detected_level = detect_cefr_level(user_input)
113
  print(f"[DEBUG] Detected CEFR = {detected_level} for context: {user_input}")
114
 
115
  # 3) Build prompt using detected CEFR
 
116
  conversation_history.append({"role": "user", "text": user_input, "CEFR": detected_level})
117
  recent_turns = conversation_history[-MAX_TURNS*2:] # *2 because each turn has user+model
118
 
119
  item = {"context": recent_turns, "CEFR": detected_level, "response": ""}
120
+ item = cefr_utils.get_CEFR_prompt(item, apply_chat_template=llm_tokenizer.apply_chat_template)
121
  print(f"[DEBUG] Prompt for response generation: {item['prompt']}")
122
 
123
  # 4) Generate response
 
125
  print(f"[{detected_level}] {response}")
126
 
127
  # 5) Update conversation history
128
+ conversation_history.append({"role": "model", "text": response, "CEFR": detected_level})
129
+
130
+ gradio_history = []
131
+ for turn in conversation_history:
132
+ if turn["role"] == "user":
133
+ gradio_history.append((turn['text'], None))
134
+ else:
135
+ gradio_history[-1] = (gradio_history[-1][0], turn["text"])
136
+
137
+ return gradio_history, ""
138
+
139
+ with gr.Blocks() as demo:
140
+ chatbot = gr.Chatbot(label="Adaptive CEFR chatbot")
141
+ msg = gr.Textbox(placeholder="Type your message here...")
142
+ msg.submit(chat, inputs=msg, outputs=[chatbot, msg], clear_on_submit=True)
143
+
144
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)