# Scraped from a Hugging Face Spaces file viewer (Space status: Sleeping).
# File size: 5,531 bytes; commit 9c4c110.
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os
from typing import Dict, Union
# --- Model and Instruction Configuration ---
# Hugging Face Hub identifier of the checkpoint used for grading.
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"

# System prompt sent with every grading request. It asks the model to answer
# with nothing but a JSON object holding the three fields listed below.
# The text deliberately starts and ends with a newline.
SYSTEM_INSTRUCTION = (
    "\n"
    "You are a strict grading assistant.\n"
    "Return ONLY a JSON object with:\n"
    "- accuracy (float 0-10)\n"
    "- grade (string A-D)\n"
    "- feedback (string)\n"
)
# ------------------------------------------
# Load the tokenizer and model a single time, at import, so every request
# handled by the app reuses the same objects.
try:
    print(f"Loading model {MODEL_ID} for Gradio app...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Token ids that end generation: the tokenizer's EOS token plus the
    # "<|eot_id|>" token looked up by name.
    TERMINATORS = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
except Exception as load_error:
    # Keep the module importable so the UI can still start; the grading
    # function checks MODEL_LOADED and reports the failure to the user.
    print(f"Error loading model or tokenizer: {load_error}")
    print("Gradio will run, but the grading function will return an error.")
    MODEL_LOADED = False
    tokenizer = model = TERMINATORS = None
else:
    # Reached only when every step above succeeded.
    MODEL_LOADED = True
def _grading_error(feedback: str) -> Dict:
    """Build the error-shaped result dict shared by every failure path."""
    return {"accuracy": 0.0, "grade": "Error", "feedback": feedback}


def _first_json_object(text: str) -> Union[Dict, None]:
    """Return the first JSON object embedded in *text*, or None if none parses."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start` and ignores
            # whatever follows it, so trailing chatter or a second object in
            # the model's reply does not break parsing.
            parsed, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # This brace did not open a valid object; try the next one.
            start = text.find("{", start + 1)
        else:
            return parsed
    return None


def grade_response(student_response: str) -> Union[Dict, str]:
    """Grade a student response with the loaded model and return the parsed JSON.

    The model is prompted with SYSTEM_INSTRUCTION plus the student's text and
    its reply is searched for a JSON object.

    Args:
        student_response: Free-text answer to grade.

    Returns:
        The JSON object decoded from the model's reply. On failure (blank
        input, model not loaded, tokenization/generation/decoding error, or no
        parseable JSON object in the reply) a dict with ``accuracy`` 0.0,
        ``grade`` "Error" and a ``feedback`` string describing the problem.
    """
    # Reject blank submissions up front instead of spending a generation on them.
    if not student_response or not student_response.strip():
        return _grading_error("No student response provided.")

    if not MODEL_LOADED:
        return _grading_error("Model failed to load. Check console for details.")

    messages = [
        {"role": "system", "content": SYSTEM_INSTRUCTION},
        {"role": "user", "content": f"Student response to grade: '{student_response}'"},
    ]

    # Tokenization, generation and decoding all run inside one guard so that
    # any failure is reported as an error result rather than an exception.
    try:
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        output_ids = model.generate(
            input_ids,
            max_new_tokens=200,
            eos_token_id=TERMINATORS,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
        )
        # Slice off the prompt tokens so only the newly generated text is decoded.
        raw_response = tokenizer.decode(
            output_ids[0][input_ids.shape[-1]:],
            skip_special_tokens=True,
        ).strip()
    except Exception as e:
        return _grading_error(f"Generation error: {e}")

    parsed = _first_json_object(raw_response)
    if parsed is None:
        # Include a bounded excerpt of the raw reply to aid debugging.
        return _grading_error(f"JSON Decode Error. Raw: {raw_response[:200]}...")
    return parsed
# --- Gradio Wrapper Function ---
def gradio_grade_wrapper(student_response: str) -> tuple[float, str, str]:
    """Adapt ``grade_response`` to the three Gradio output components.

    Args:
        student_response: Text entered in the input box.

    Returns:
        ``(accuracy, grade, feedback)``. The fields come from model-generated
        JSON, so their types are not guaranteed; they are coerced here to a
        float and two strings, with defaults for missing or unusable values.
    """
    result = grade_response(student_response)

    # grade_response is annotated as possibly returning a plain string;
    # surface such a value as the feedback text.
    if not isinstance(result, dict):
        return (0.0, "ERROR", str(result))

    # The model may emit accuracy as a string ("8"), null, or something
    # non-numeric; fall back to 0.0 when it cannot be converted.
    try:
        accuracy = float(result.get("accuracy", 0.0))
    except (TypeError, ValueError):
        accuracy = 0.0

    # Treat a JSON null the same as a missing key, and stringify anything else.
    grade = result.get("grade")
    feedback = result.get("feedback")
    return (
        accuracy,
        "N/A" if grade is None else str(grade),
        "No feedback generated." if feedback is None else str(feedback),
    )
# --- Gradio Interface Definition ---
# NOTE(review): the source this was recovered from had lost its indentation.
# The nesting below is reconstructed; in particular the feedback box is assumed
# to sit beneath the results row rather than inside it — confirm the layout.
with gr.Blocks(theme=gr.themes.Soft(), title="LLM Essay Grader") as demo:
    # Page title and a short description of what the app does.
    gr.Markdown("# 📝 LLM Essay Grading Assistant (Llama-3.2-1B-Instruct)")
    gr.Markdown(
        "Enter a student's response below to receive an automated grade, accuracy score, and feedback "
        "from the Llama-3.2-1B-Instruct model."
    )

    # Input row: the answer box (scale 3) next to the submit button (scale 1).
    with gr.Row():
        student_input = gr.Textbox(
            label="Student Response to Grade",
            placeholder="E.g., 'The main causes of the World War 2 were economic depression and poor leadership.'",
            lines=5,
            scale=3,
        )
        grade_button = gr.Button(
            "Submit for Grading",
            scale=1,
            variant="primary",
        )

    gr.Markdown("---")
    gr.Markdown("## Grading Results")

    # Read-only result fields: numeric score and letter grade side by side.
    with gr.Row():
        accuracy_output = gr.Number(
            label="Accuracy (0-10)",
            interactive=False,
            precision=1,
        )
        grade_output = gr.Textbox(
            label="Grade (A-D)",
            interactive=False,
        )
    feedback_output = gr.Textbox(
        label="Detailed Feedback",
        interactive=False,
        lines=4,
        max_lines=10,
    )

    # Clicking the button runs the wrapper and fills the three result fields
    # in the same order as the tuple it returns.
    grade_button.click(
        fn=gradio_grade_wrapper,
        inputs=[student_input],
        outputs=[accuracy_output, grade_output, feedback_output],
    )

    # Sample submissions the user can load into the input box.
    gr.Examples(
        examples=[
            ["The Earth is a cube and its main moon is Mars, which proves that gravity is fake."],
            ["A proper noun is a name used to designate a single, specific person, place, or thing, and is always capitalized."],
            ["The two main drivers of climate change are the burning of fossil fuels (releasing greenhouse gases) and deforestation."],
        ],
        inputs=student_input,
    )
# Launch the Gradio app only when this file is executed as a script, so that
# importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()