import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os
from typing import Dict, Union

# --- Model and Instruction Configuration ---
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
SYSTEM_INSTRUCTION = """
You are a strict grading assistant.
Return ONLY a JSON object with:
- accuracy (float 0-10)
- grade (string A-D)
- feedback (string)
"""
# ------------------------------------------

# Load Model and Tokenizer once for the entire application.
# If loading fails the app still starts; grade_response() then returns an
# error dict instead of raising, so the UI stays usable.
try:
    print(f"Loading model {MODEL_ID} for Gradio app...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Generation stops at whichever of these token ids appears first:
    # the tokenizer's EOS token or the "<|eot_id|>" end-of-turn token.
    TERMINATORS = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    MODEL_LOADED = True
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Gradio will run, but the grading function will return an error.")
    MODEL_LOADED = False
    tokenizer, model, TERMINATORS = None, None, None


def _error_result(feedback: str) -> Dict:
    """Build the structured error payload returned by grade_response."""
    return {"accuracy": 0.0, "grade": "Error", "feedback": feedback}


def grade_response(student_response: str) -> Union[Dict, str]:
    """Grade a student response with the loaded LLM.

    Args:
        student_response: Free-text answer to be graded.

    Returns:
        The JSON object parsed from the model output. The system prompt asks
        for keys ``accuracy``, ``grade`` and ``feedback``, but the model is not
        guaranteed to honour the key set or the value types, so callers should
        validate. On failure a dict of the same shape with ``grade == "Error"``
        is returned. Failure cases are: model not loaded, blank input, template
        or generation failure, or output that is not valid JSON.
    """
    if not MODEL_LOADED:
        return _error_result("Model failed to load. Check console for details.")

    # Don't spend a generation call (and risk a hallucinated grade) on blank input.
    if not student_response or not student_response.strip():
        return _error_result("No student response provided.")

    # 1. Construct the Message List
    messages = [
        {"role": "system", "content": SYSTEM_INSTRUCTION},
        {"role": "user", "content": f"Student response to grade: '{student_response}'"},
    ]

    # 2-3. Apply Chat Template, Tokenize and Generate. Both steps are inside
    # the try so any failure becomes a structured error, not a UI crash.
    try:
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        # This is a single unpadded sequence, so every position is a real
        # token and an all-ones mask is exact. The tokenizer may define no
        # pad token, hence the EOS fallback for pad_token_id.
        pad_token_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )
        output_ids = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=200,
            eos_token_id=TERMINATORS,
            pad_token_id=pad_token_id,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
        )
    except Exception as e:
        return _error_result(f"Generation error: {e}")

    # 4. Decode only the newly generated tokens (slice off the prompt).
    raw_response = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
    ).strip()

    # 5. Parse the JSON Output. The model may surround the object with prose
    # or code fences, so keep the span from the first '{' to the last '}'.
    # If either brace is missing the slice is empty and json.loads raises.
    start_index = raw_response.find("{")
    end_index = raw_response.rfind("}") + 1
    try:
        return json.loads(raw_response[start_index:end_index])
    except json.JSONDecodeError:
        return _error_result(f"JSON Decode Error. Raw: {raw_response[:200]}...")


def _to_float(value: object, default: float = 0.0) -> float:
    """Coerce a model-supplied score (number, numeric string, None...) to float."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


# --- Gradio Wrapper Function ---
def gradio_grade_wrapper(student_response: str) -> tuple[float, str, str]:
    """Adapt grade_response() to the (accuracy, grade, feedback) Gradio outputs.

    Values come from LLM-generated JSON, so they are coerced here. The
    ``gr.Number`` output always receives a float and the textboxes always
    receive strings, whatever types the model emitted.
    """
    result = grade_response(student_response)

    # Check if the result is a dictionary (the expected structured output)
    if isinstance(result, dict):
        return (
            _to_float(result.get("accuracy", 0.0)),
            str(result.get("grade", "N/A")),
            str(result.get("feedback", "No feedback generated.")),
        )

    # Defensive only: grade_response currently always returns a dict.
    return (0.0, "ERROR", str(result))


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="LLM Essay Grader") as demo:
    gr.Markdown("# 📝 LLM Essay Grading Assistant (Llama-3.2-1B-Instruct)")
    gr.Markdown(
        "Enter a student's response below to receive an automated grade, accuracy score, and feedback "
        "from the Llama-3.2-1B-Instruct model."
    )

    # Input Component
    with gr.Row():
        student_input = gr.Textbox(
            label="Student Response to Grade",
            placeholder="E.g., 'The main causes of the World War 2 were economic depression and poor leadership.'",
            lines=5,
            scale=3,
        )
        grade_button = gr.Button("Submit for Grading", scale=1, variant="primary")

    gr.Markdown("---")
    gr.Markdown("## Grading Results")

    # Output Components arranged in a Row for visual clarity
    with gr.Row():
        accuracy_output = gr.Number(label="Accuracy (0-10)", interactive=False, precision=1)
        grade_output = gr.Textbox(label="Grade (A-D)", interactive=False)

    feedback_output = gr.Textbox(
        label="Detailed Feedback",
        interactive=False,
        lines=4,
        max_lines=10,
    )

    # Event Listener: Connect the button click to the wrapper function
    grade_button.click(
        fn=gradio_grade_wrapper,
        inputs=[student_input],
        outputs=[accuracy_output, grade_output, feedback_output],
    )

    # Add Examples
    gr.Examples(
        examples=[
            ["The Earth is a cube and its main moon is Mars, which proves that gravity is fake."],
            ["A proper noun is a name used to designate a single, specific person, place, or thing, and is always capitalized."],
            ["The two main drivers of climate change are the burning of fossil fuels (releasing greenhouse gases) and deforestation."],
        ],
        inputs=student_input,
    )

# Launch the Gradio App
if __name__ == "__main__":
    demo.launch()