| |
|
|
| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import logging |
|
|
| |
# Application-wide logging at INFO level, plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Filled in by load_model(); None until it assigns them.
model = tokenizer = None

# Predicted class index -> label shown in the UI.
label_mapping = dict(
    enumerate(
        (
            "✅ Correct",
            "🤔 Conceptually Flawed",
            "🔢 Computationally Flawed",
        )
    )
)
|
|
def load_model(model_name: str = "distilbert-base-uncased"):
    """Load the tokenizer and sequence classifier into the module globals.

    This is currently placeholder loading: by default it pulls the generic
    ``distilbert-base-uncased`` checkpoint, whose classification head is not
    trained for this task, so predictions are not meaningful until a
    fine-tuned checkpoint is passed as ``model_name``.

    Args:
        model_name: Hugging Face Hub id or local directory handed to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A status string: "Model loaded successfully!" on success, or
        "Error loading model: <reason>" on failure. Errors are logged (with
        traceback) and reported through this string rather than raised.
    """
    global model, tokenizer

    try:
        logger.warning("Using placeholder model loading - replace with your actual model!")

        # Load into locals and publish only after BOTH succeed, so a failed
        # (re)load never leaves a new tokenizer paired with an old/missing model.
        new_tokenizer = AutoTokenizer.from_pretrained(model_name)
        new_model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            # One output logit per entry in the UI label map.
            num_labels=len(label_mapping),
            # Tolerate checkpoints whose classifier head size differs from num_labels.
            ignore_mismatched_sizes=True,
        )
        # Inference only; made explicit so the contract does not depend on loader defaults.
        new_model.eval()
    except Exception as e:
        # logger.exception preserves the traceback, which logger.error discarded.
        logger.exception("Error loading model: %s", e)
        return f"Error loading model: {e}"

    model, tokenizer = new_model, new_tokenizer
    logger.info("Model loaded successfully")
    return "Model loaded successfully!"
|
|
| def classify_solution(question: str, solution: str): |
| """ |
| Classify the math solution |
| Returns: (classification_label, confidence_score, explanation) |
| """ |
| if not question.strip() or not solution.strip(): |
| return "Please fill in both fields", 0.0, "" |
| |
| if not model or not tokenizer: |
| return "Model not loaded", 0.0, "" |
| |
| try: |
| |
| text_input = f"Question: {question}\nSolution: {solution}" |
| |
| |
| inputs = tokenizer( |
| text_input, |
| return_tensors="pt", |
| truncation=True, |
| padding=True, |
| max_length=512 |
| ) |
| |
| |
| with torch.no_grad(): |
| outputs = model(**inputs) |
| predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) |
| predicted_class = torch.argmax(predictions, dim=-1).item() |
| confidence = predictions[0][predicted_class].item() |
| |
| classification = label_mapping[predicted_class] |
| |
| |
| explanations = { |
| 0: "The mathematical approach and calculations are both sound.", |
| 1: "The approach or understanding has fundamental issues.", |
| 2: "The approach is correct, but there are calculation errors." |
| } |
| |
| explanation = explanations[predicted_class] |
| |
| return classification, f"{confidence:.2%}", explanation |
| |
| except Exception as e: |
| logger.error(f"Error during classification: {e}") |
| return f"Classification error: {str(e)}", "0%", "" |
|
|
| |
# Load weights eagerly at import so the classifier is ready before the UI
# accepts requests. load_model() logs its own failures, and
# classify_solution() answers "Model not loaded" if loading did not succeed.
load_model()

# Gradio interface: inputs on the left, read-only results on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="Math Solution Classifier") as app:
    gr.Markdown("# 🧮 Math Solution Classifier")
    gr.Markdown(
        "Classify math solutions as correct, conceptually flawed, "
        "or computationally flawed."
    )

    with gr.Row():
        # Left column: the problem statement and the candidate solution.
        with gr.Column():
            question_input = gr.Textbox(
                lines=3,
                label="Math Question",
                placeholder="e.g., Solve for x: 2x + 5 = 13",
            )
            solution_input = gr.Textbox(
                lines=5,
                label="Proposed Solution",
                placeholder="e.g., 2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4",
            )
            classify_btn = gr.Button("Classify Solution", variant="primary")

        # Right column: filled in from classify_solution()'s 3-tuple.
        with gr.Column():
            classification_output = gr.Textbox(interactive=False, label="Classification")
            confidence_output = gr.Textbox(interactive=False, label="Confidence")
            explanation_output = gr.Textbox(lines=3, interactive=False, label="Explanation")

    # Clickable sample (question, solution) pairs that populate the two inputs.
    gr.Examples(
        inputs=[question_input, solution_input],
        examples=[
            ["Solve for x: 2x + 5 = 13", "2x + 5 = 13\n2x = 13 - 5\n2x = 8\nx = 4"],
            ["Find the derivative of f(x) = x²", "f'(x) = 2x + 1"],
            ["What is 15% of 200?", "15% = 15/100 = 0.15\n0.15 × 200 = 30"],
        ],
    )

    # Two text inputs in, three text outputs out.
    classify_btn.click(
        fn=classify_solution,
        inputs=[question_input, solution_input],
        outputs=[classification_output, confidence_output, explanation_output],
    )


if __name__ == "__main__":
    app.launch()