# app.py — BERT-based automated essay scoring demo with SHAP explanations (Gradio).
# Source: Hugging Face Space by faturbbx, commit 20cec0e ("Update app.py", verified).
import gradio as gr
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel, BertPreTrainedModel, BertConfig
import shap
import numpy as np
import os
import re
# --- 1. Define the Model and Other Classes ---
class BertForEssayScoring(BertPreTrainedModel):
    """BERT encoder topped with a single-output linear regression head.

    Produces one continuous score per input sequence, regressed from the
    pooled [CLS] representation.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.3)
        self.regressor = nn.Linear(config.hidden_size, 1)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Return ``(loss, score)`` when labels are given, else ``(score,)``.

        ``score`` has shape (batch,) — the trailing singleton dim from the
        regression head is squeezed away.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        score = self.regressor(self.dropout(encoder_out.pooler_output)).squeeze(-1)
        if labels is None:
            return (score,)
        # MSE against the (normalized) gold scores during training.
        return (nn.MSELoss()(score, labels), score)
# --- 2. Load the Model and Tokenizer ---
# Path to the fine-tuned model directory (must match the uploaded folder name).
path_to_model = './BERT_MODEL'

# Prefer GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and fine-tuned weights from the local directory, then
# move the model to the chosen device and switch it to inference mode
# (both Module.to and Module.eval return the module itself).
tokenizer = BertTokenizer.from_pretrained(path_to_model)
model = BertForEssayScoring.from_pretrained(path_to_model).to(device).eval()
# --- 3. Create the Prediction and SHAP Explanation Function ---
# This function will be called by the SHAP explainer
def predict_for_shap(inputs):
    """Score a batch of essay texts with the fine-tuned model.

    Args:
        inputs: A list of strings, a single string, or a numpy array of
            strings (SHAP passes numpy arrays).

    Returns:
        A numpy array of raw (normalized-scale) model scores.
    """
    texts = inputs.tolist() if isinstance(inputs, np.ndarray) else inputs
    batch = tokenizer(
        texts,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=512,
    ).to(device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        (scores,) = model(**batch)
    return scores.cpu().numpy()
# Initialize the SHAP explainer
# Pairing the prediction function with the tokenizer lets SHAP attribute
# the predicted score to individual tokens of the input essay.
explainer = shap.Explainer(predict_for_shap, tokenizer)
# This is the main function for the Gradio interface
def get_prediction_and_explanation(essay):
    """Score an essay and explain the prediction with SHAP.

    Args:
        essay: Raw essay text from the Gradio textbox.

    Returns:
        A 4-tuple matching the Gradio outputs:
        (rounded denormalized score,
         dict for gr.HighlightedText — {"text": ..., "entities": [(token, label), ...]},
         newline-joined top-10 score-strengthening words,
         newline-joined top-10 score-weakening words).
    """
    # Strip punctuation before explaining, mirroring the training notebook.
    essay_clean = re.sub(r'[^\w\s]', '', essay)

    # Token-level SHAP attributions for the cleaned essay.
    shap_values = explainer([essay_clean])
    tokens = shap_values.data[0]
    values = shap_values.values[0]

    # Min-max normalize attributions to [0, 1] for coloring.
    # BUGFIX: guard against a zero range (e.g. a one-token essay or uniform
    # attributions), which previously divided by zero and produced NaNs.
    value_range = np.max(values) - np.min(values)
    if value_range > 0:
        normalized_values = (values - np.min(values)) / value_range
    else:
        normalized_values = np.full(len(values), 0.5)

    # Tokens above the midpoint are labeled score-strengthening ("POS"),
    # the rest score-weakening ("NEG") — labels drive the highlight colors.
    highlighted_text = [
        (token, "POS" if value > 0.5 else "NEG")
        for token, value in zip(tokens, normalized_values)
    ]

    # Denormalize the model's output into the rubric's score range.
    # NOTE(review): these bounds come from the training notebook — confirm
    # they match the dataset actually used for fine-tuning.
    min_score = 1
    max_score = 6
    normalized_pred = predict_for_shap(essay)[0]
    denormalized_score = normalized_pred * (max_score - min_score) + min_score
    final_score = round(float(denormalized_score), 2)

    # Rank tokens by raw SHAP value and take the 10 most positive (reversed
    # so the strongest comes first) and 10 most negative impacts.
    word_impacts = sorted(zip(tokens, values), key=lambda x: x[1])
    top_positive = word_impacts[-10:][::-1]
    top_negative = word_impacts[:10]
    strong_words = "\n".join(f"- {word} ({impact:.3f})" for word, impact in top_positive)
    weak_words = "\n".join(f"- {word} ({impact:.3f})" for word, impact in top_negative)

    return final_score, {"text": essay, "entities": highlighted_text}, strong_words, weak_words
# --- 4. Build the Gradio Interface ---
# Component creation order inside the Blocks context defines the page layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📝 Automated Essay Scoring with Explanation")
    gr.Markdown("Enter an essay below to get its predicted score. The model will also highlight words that influenced the score.")
    with gr.Row():
        with gr.Column(scale=2):
            # Left column: essay entry and the action button.
            essay_input = gr.Textbox(lines=15, label="Essay Input", placeholder="Paste the essay here...")
            submit_button = gr.Button("Get Score", variant="primary")
        with gr.Column(scale=1):
            # Right column: numeric score plus the top influential words.
            predicted_score = gr.Number(label="Predicted Score")
            gr.Markdown("### Words that Strengthen the Score")
            strong_words_output = gr.Textbox(label="Top 10 Positive Words", lines=10)
            gr.Markdown("### Words that Weaken the Score")
            weak_words_output = gr.Textbox(label="Top 10 Negative Words", lines=10)
    gr.Markdown("---")
    gr.Markdown("### Score Influence Visualization")
    # Token-level highlighting: "POS" tokens render green, "NEG" red.
    highlight_output = gr.HighlightedText(
        label="Word Importance",
        color_map={"POS": "green", "NEG": "red"},
        show_legend=True
    )
    # Wire the button to the scoring/explanation pipeline; output order must
    # match the return order of get_prediction_and_explanation.
    submit_button.click(
        fn=get_prediction_and_explanation,
        inputs=essay_input,
        outputs=[predicted_score, highlight_output, strong_words_output, weak_words_output]
    )
    # Clickable sample essays for trying the demo quickly.
    gr.Examples(
        examples=[
            ["I believe that technology has made the world a better place. It connects people from all over the globe and provides access to a vast amount of information. However, it also has its downsides, such as the spread of misinformation and the potential for addiction."],
            ["The main character in the novel was not very believable. His motivations were unclear, and his actions often seemed random. This made it difficult to connect with the story on an emotional level."]
        ],
        inputs=essay_input
    )
demo.launch()