import os

import gradio as gr
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Authenticate with the Hub (HF_TOKEN must be set as a Space secret)
login(token=os.environ["HF_TOKEN"])

# Base model and LoRA adapter path.
# NOTE: the base checkpoint below is an assumption -- the app only says
# "Mistral + LoRA"; use whichever checkpoint the adapter was trained on.
base_model = "mistralai/Mistral-7B-v0.1"
adapter_path = "btnotpt/ielts_eval"
# BitsAndBytes config for 4-bit NF4 loading
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
# Load the tokenizer and quantized base model, then attach the LoRA adapter
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = tokenizer.eos_token  # Mistral tokenizers ship without a pad token
base = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, adapter_path)
model.eval()
# Inference function
def generate_feedback(topic, essay):
    # Prompt template; this should match the format used during fine-tuning
    prompt = f"""### Evaluate the given IELTS Task 2 essay in response to a topic and provide feedback.
### Topic:
{topic}
### Essay:
{essay}
### Feedback:"""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=700,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    # Decode only the newly generated tokens; slicing the decoded string by
    # character count is fragile since detokenization is not an exact round-trip
    generated = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)
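
# Optional local smoke test (a minimal sketch; the sample topic below is a
# hypothetical placeholder, not part of the original app). Uncomment to check
# generation end-to-end before launching the UI.
# if __name__ == "__main__":
#     print(generate_feedback(
#         "Some people believe community work should be compulsory for high "
#         "school students. To what extent do you agree or disagree?",
#         "Paste a sample essay here...",
#     ))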
# Gradio app
gr.Interface(
    fn=generate_feedback,
    inputs=[
        gr.Textbox(label="IELTS Essay Topic", placeholder="Enter the essay topic here", lines=2),
        gr.Textbox(label="Student Essay", placeholder="Paste the student's essay here", lines=20),
    ],
    outputs="text",
    title="IELTS Essay Feedback Generator (Mistral + LoRA)",
).launch()