Upload app.py
Browse files
app.py
CHANGED
|
@@ -15,18 +15,17 @@ tokenizer = None
|
|
| 15 |
label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}
|
| 16 |
|
| 17 |
def load_model():
|
| 18 |
-
"""Load your trained LoRA adapter with
|
| 19 |
global model, tokenizer
|
| 20 |
|
| 21 |
try:
|
| 22 |
-
from peft import
|
| 23 |
|
| 24 |
-
# Load the LoRA adapter model for
|
| 25 |
-
model =
|
| 26 |
"./lora_adapter", # Path to your adapter files
|
| 27 |
-
torch_dtype=torch.
|
| 28 |
-
device_map="
|
| 29 |
-
low_cpu_mem_usage=True # Optimize for low memory
|
| 30 |
)
|
| 31 |
|
| 32 |
# Load tokenizer from the same directory
|
|
@@ -37,23 +36,27 @@ def load_model():
|
|
| 37 |
tokenizer.pad_token = tokenizer.eos_token
|
| 38 |
logger.info("Set pad_token to eos_token")
|
| 39 |
|
| 40 |
-
logger.info("LoRA model loaded successfully")
|
| 41 |
-
return "LoRA model loaded successfully!"
|
| 42 |
|
| 43 |
except Exception as e:
|
| 44 |
logger.error(f"Error loading LoRA model: {e}")
|
| 45 |
# Fallback to placeholder for testing
|
| 46 |
logger.warning("Using placeholder model loading - replace with your actual model!")
|
| 47 |
|
| 48 |
-
model_name = "
|
| 49 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 50 |
|
| 51 |
# Fix padding token for fallback model too
|
| 52 |
if tokenizer.pad_token is None:
|
| 53 |
tokenizer.pad_token = tokenizer.eos_token
|
| 54 |
|
| 55 |
-
from transformers import
|
| 56 |
-
model =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
return f"Fallback model loaded. LoRA error: {e}"
|
| 59 |
|
|
@@ -189,7 +192,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 189 |
|
| 190 |
with gr.Column():
|
| 191 |
classification_output = gr.Textbox(label="Classification", interactive=False)
|
| 192 |
-
|
| 193 |
explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)
|
| 194 |
|
| 195 |
# Examples
|
|
@@ -214,7 +217,7 @@ with gr.Blocks(title="Math Solution Classifier", theme=gr.themes.Soft()) as app:
|
|
| 214 |
classify_btn.click(
|
| 215 |
fn=classify_solution,
|
| 216 |
inputs=[question_input, solution_input],
|
| 217 |
-
outputs=[classification_output,
|
| 218 |
)
|
| 219 |
|
| 220 |
if __name__ == "__main__":
|
|
|
|
| 15 |
label_mapping = {0: "✅ Correct", 1: "🤔 Conceptually Flawed", 2: "🔢 Computationally Flawed"}
|
| 16 |
|
| 17 |
def load_model():
|
| 18 |
+
"""Load your trained LoRA adapter with classification head"""
|
| 19 |
global model, tokenizer
|
| 20 |
|
| 21 |
try:
|
| 22 |
+
from peft import AutoPeftModelForSequenceClassification # Back to classification
|
| 23 |
|
| 24 |
+
# Load the LoRA adapter model for classification
|
| 25 |
+
model = AutoPeftModelForSequenceClassification.from_pretrained(
|
| 26 |
"./lora_adapter", # Path to your adapter files
|
| 27 |
+
torch_dtype=torch.float16,
|
| 28 |
+
device_map="auto"
|
|
|
|
| 29 |
)
|
| 30 |
|
| 31 |
# Load tokenizer from the same directory
|
|
|
|
| 36 |
tokenizer.pad_token = tokenizer.eos_token
|
| 37 |
logger.info("Set pad_token to eos_token")
|
| 38 |
|
| 39 |
+
logger.info("LoRA classification model loaded successfully")
|
| 40 |
+
return "LoRA classification model loaded successfully!"
|
| 41 |
|
| 42 |
except Exception as e:
|
| 43 |
logger.error(f"Error loading LoRA model: {e}")
|
| 44 |
# Fallback to placeholder for testing
|
| 45 |
logger.warning("Using placeholder model loading - replace with your actual model!")
|
| 46 |
|
| 47 |
+
model_name = "distilbert-base-uncased" # Simple fallback
|
| 48 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 49 |
|
| 50 |
# Fix padding token for fallback model too
|
| 51 |
if tokenizer.pad_token is None:
|
| 52 |
tokenizer.pad_token = tokenizer.eos_token
|
| 53 |
|
| 54 |
+
from transformers import AutoModelForSequenceClassification
|
| 55 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 56 |
+
model_name,
|
| 57 |
+
num_labels=3,
|
| 58 |
+
ignore_mismatched_sizes=True
|
| 59 |
+
)
|
| 60 |
|
| 61 |
return f"Fallback model loaded. LoRA error: {e}"
|
| 62 |
|
|
|
|
| 192 |
|
| 193 |
with gr.Column():
|
| 194 |
classification_output = gr.Textbox(label="Classification", interactive=False)
|
| 195 |
+
confidence_output = gr.Textbox(label="Confidence", interactive=False)
|
| 196 |
explanation_output = gr.Textbox(label="Explanation", interactive=False, lines=3)
|
| 197 |
|
| 198 |
# Examples
|
|
|
|
| 217 |
classify_btn.click(
|
| 218 |
fn=classify_solution,
|
| 219 |
inputs=[question_input, solution_input],
|
| 220 |
+
outputs=[classification_output, confidence_output, explanation_output]
|
| 221 |
)
|
| 222 |
|
| 223 |
if __name__ == "__main__":
|