# shivansarora's picture
# Update app.py
# 4996f90 verified
import argparse
from tqdm import tqdm
import json
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append(f'../source')
import cefr_utils
import torch
from huggingface_hub import snapshot_download
from huggingface_hub import login
import gradio as gr
import os
# Free-text descriptors of spoken-interaction ability, keyed by CEFR level
# label ("A1" .. "C2").
# NOTE(review): nothing in the visible part of this file reads `description`;
# presumably kept for prompt building or reference -- confirm before removing.
description = {
    "C2": "Has a good command of idiomatic expressions and colloquialisms with awareness of connotative levels of meaning. Can convey finer shades of meaning precisely by using, with reasonable accuracy, a wide range of modification devices. Can backtrack and restructure around a difficulty so smoothly that the interlocutor is hardly aware of it.",
    "C1": "Can express themselves fluently and spontaneously, almost effortlessly. Has a good command of a broad lexical repertoire allowing gaps to be readily overcome with circumlocutions. There is little obvious searching for expressions or avoidance strategies; only a conceptually difficult subject can hinder a natural, smooth flow of language.",
    "B2": "Can interact with a degree of fluency and spontaneity that makes regular interaction, and sustained relationships with users of the target language, quite possible without imposing strain on either party. Can highlight the personal significance of events and experiences, and account for and sustain views clearly by providing relevant explanations and arguments.",
    "B1": "Can communicate with some confidence on familiar routine and non-routine matters related to their interests and professional field. Can exchange, check and confirm information, deal with less routine situations and explain why something is a problem. Can express thoughts on more abstract, cultural topics such as films, books, music, etc.",
    "A2": "Can interact with reasonable ease in structured situations and short conversations, provided the other person helps if necessary. Can manage simple, routine exchanges without undue effort; can ask and answer questions and exchange ideas and information on familiar topics in predictable everyday situations.",
    "A1": "Can interact in a simple way but communication is totally dependent on repetition at a slower rate, rephrasing and repair. Can ask and answer simple questions, initiate and respond to simple statements in areas of immediate need or on very familiar topics."
}
def parse_response(response, format="A: "):
    """Return the text that follows the first ``format`` marker in ``response``.

    Args:
        response: Raw generated text.
        format: Marker string that introduces the answer.

    Returns:
        Everything after the first occurrence of the marker, or ``response``
        unchanged when the marker does not occur.
    """
    marker_at = response.find(format)
    if marker_at < 0:
        # Marker absent: hand the text back untouched.
        return response
    answer_start = marker_at + len(format)
    return response[answer_start:]
"""CEFR Classifier Script"""
# Location of the classifier checkpoint inside the downloaded repository.
ckpt_path = "CEFR_evaluator/level_estimator.ckpt"

# Authenticate against the Hugging Face Hub with the token stored in the
# "Token1" environment variable. Calling login(token=None) falls back to an
# interactive prompt, which cannot be answered in a headless deployment, so
# only log in when a token is actually configured.
token = os.environ.get("Token1")
if token:
    login(token=token)
else:
    print("[WARN] Environment variable 'Token1' is not set; skipping Hugging Face login.")

# Fetch the classifier repository into ./CEFR_evaluator.
snapshot_download(repo_id="shivansarora/cefr-evaluator", local_dir="CEFR_evaluator")

# Imported only after the download so the CEFR_evaluator directory is
# guaranteed to be populated at import time.
# NOTE(review): if level_model.py ships with this app rather than with the
# downloaded repo, the ordering is harmless -- confirm where it lives.
from CEFR_evaluator.level_model import LevelEstimaterClassification  # noqa: E402

cefr_model = LevelEstimaterClassification.load_from_checkpoint(
    ckpt_path,
    pretrained_model="bert-base-cased",
    lm_layer=11,
    num_labels=6,
    alpha=0.2,
    strict=False,
    with_loss_weight=False,
    corpus_path=None,
    test_corpus_path=None,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Re-apply the checkpoint weights non-strictly so that keys the model does
# not declare are ignored instead of raising.
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
cefr_model.load_state_dict(state_dict, strict=False)  # ignore extra keys

# Move to the target device and switch to inference mode exactly once
# (the original repeated both calls).
cefr_model.to(device)
cefr_model.eval()

cefr_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Index -> CEFR label lookup used by detect_cefr_level.
labels = ["A1", "A2", "B1", "B2", "C1", "C2"]
def detect_cefr_level(text: str) -> str:
    """Classify ``text`` and return its CEFR label.

    The text is tokenized (truncated to 512 tokens), run through the CEFR
    classifier without gradient tracking, and the first prediction is used
    as an index into ``labels``.

    Args:
        text: The text to classify.

    Returns:
        One of the entries of ``labels``.
    """
    encoded = cefr_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        model_output = cefr_model(encoded)

    # The model may return either a 3-tuple (predictions in the middle slot)
    # or the predictions on their own.
    if isinstance(model_output, tuple):
        _, predictions, _ = model_output
    else:
        predictions = model_output

    label_index = int(predictions[0])
    return labels[label_index]
"""Response Generation Script"""
def get_response(prompt):
    """Generate a reply for ``prompt`` and strip everything up to the answer marker.

    Args:
        prompt: Fully formatted prompt string for the LLM.

    Returns:
        The generated text after ``parse_response`` has been applied.
    """
    generated = cefr_utils.generate(llm_model, llm_tokenizer, [prompt])
    if isinstance(generated, list):
        # A list of pieces is concatenated into one string.
        text = "".join(generated)
    else:
        text = str(generated)
    return parse_response(text)
# Quantization settings handed to the model loader below: 4-bit NF4 weights,
# float16 compute dtype, double quantization enabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
# Hub identifier of the chat model used for response generation.
model_string = "meta-llama/Meta-Llama-3-8B-Instruct"
llm_tokenizer = AutoTokenizer.from_pretrained(model_string)
# Reuse the end-of-sequence token as the padding token.
llm_tokenizer.pad_token = llm_tokenizer.eos_token
# Load the model with the quantization config above; device placement is
# delegated to the loader via device_map="auto".
llm_model = AutoModelForCausalLM.from_pretrained(
    model_string,
    device_map="auto",
    quantization_config=bnb_config
)
responses = []

# Flat, chronological log of turns. Each entry is a dict with the keys
# "role" ("user" or "model"), "text" and "CEFR".
# NOTE(review): this list lives at module level, so every client connected
# to the app reads and appends to the same history.
conversation_history = []
MAX_TURNS = 5  # Limit the number of turns to keep context manageable


def _to_gradio_pairs(history):
    """Fold role-tagged turns into the (user, reply) tuples shown by the chatbot."""
    pairs = []
    for turn in history:
        if turn["role"] == "user":
            # Open a new exchange; the reply slot is filled by the next model turn.
            pairs.append((turn["text"], None))
        elif pairs:
            pairs[-1] = (pairs[-1][0], turn["text"])
        else:
            # Model turn with no preceding user message: show it on its own
            # instead of failing on an empty list.
            pairs.append((None, turn["text"]))
    return pairs


def chat(user_input):
    """Handle one submitted message and return the updated chat display.

    Args:
        user_input: Text submitted from the textbox.

    Returns:
        A ``(history, textbox_value)`` tuple. ``history`` is a list of
        ``(user_text, reply_text)`` tuples; ``textbox_value`` is ``""`` after
        a successful turn, or a hint when the input was empty.
    """
    # 1) Reject empty input. The history is converted to display tuples here
    #    as well, so both return paths give the chatbot the same format.
    if not user_input or not user_input.strip():
        return _to_gradio_pairs(conversation_history), "Please enter a message."

    # 2) Detect CEFR from input context
    detected_level = detect_cefr_level(user_input)
    print(f"[DEBUG] Detected CEFR = {detected_level} for context: {user_input}")

    # 3) Build prompt using detected CEFR
    conversation_history.append({"role": "user", "text": user_input, "CEFR": detected_level})
    recent_turns = conversation_history[-MAX_TURNS * 2:]  # *2 because each turn has user+model
    item = {"context": recent_turns, "CEFR": detected_level, "response": ""}
    item = cefr_utils.get_CEFR_prompt(item, apply_chat_template=llm_tokenizer.apply_chat_template)
    print(f"[DEBUG] Prompt for response generation: {item['prompt']}")

    # 4) Generate response
    response = get_response(item['prompt'])
    print(f"[{detected_level}] {response}")

    # 5) Update conversation history
    conversation_history.append({"role": "model", "text": response, "CEFR": detected_level})

    return _to_gradio_pairs(conversation_history), ""
# Build the UI: a chat display plus a single-line textbox for user input.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Adaptive CEFR chatbot")
    msg = gr.Textbox(placeholder="Type your message here...")
    # Submitting the textbox calls chat(); its two return values update the
    # chat display and the textbox, in that order.
    msg.submit(chat, inputs=msg, outputs=[chatbot, msg])
# Serve on all interfaces at port 7860, without a public share link and with
# server-side rendering turned off.
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, ssr_mode=False)