# lessoncraft/app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
print("πŸ”„ Loading CCISD TEKS Educational Assistant...")
print(" This may take 1-2 minutes on first launch...")
# Load base model + LoRA adapter
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER_MODEL = "robworks-software/ccisd-teks-educator-mistral7b"
print(f"πŸ“₯ Loading base model: {BASE_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f" Using device: {device}")
# Load base model (bfloat16 on CPU, where float16 kernels are poorly
# supported; float16 on GPU)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16 if device == "cpu" else torch.float16,
    low_cpu_mem_usage=True,
)
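# A lower-memory alternative (a sketch, assuming bitsandbytes is installed and
# a CUDA GPU is available; not used here): load the base weights in 4-bit.
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       BASE_MODEL,
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#   )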
print(f"πŸ”§ Loading LoRA adapter: {ADAPTER_MODEL}...")
model = PeftModel.from_pretrained(model, ADAPTER_MODEL)
model = model.to(device)
model.eval() # Set to evaluation mode
print("βœ… Model loaded successfully!")
def chat(message, history):
    """
    Generate a response using the Mistral instruct format.

    Args:
        message: Current user message.
        history: List of [user_msg, bot_msg] pairs (currently unused; see the
            sketch below).

    Returns:
        Generated response string.
    """
    # Simplified prompt: only the current message is used; history is ignored
    prompt = f"[INST] You are a Texas TEKS educational expert. Answer this question clearly and helpfully:\n\n{message} [/INST]"
    # Tokenize the prompt and move tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate with sampling parameters tuned for varied but on-topic answers
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            min_new_tokens=50,
            temperature=0.8,
            top_p=0.95,
            top_k=50,
            do_sample=True,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
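    # Instead of decoding the full sequence and splitting on "[/INST]" (below),
    # the prompt tokens could be sliced off directly (a sketch):
    #   new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    #   response = tokenizer.decode(new_tokens, skip_special_tokens=True)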
    # Decode the full sequence (prompt + completion)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the answer after the [/INST] marker
    if "[/INST]" in response:
        response = response.split("[/INST]")[-1].strip()

    # Strip an echoed copy of the question, if the model repeated it
    if response.startswith(message):
        response = response[len(message):].strip()

    return response
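# Quick local smoke test (hypothetical; bypasses Gradio entirely):
#   print(chat("What are real-world applications of linear equations?", []))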
# Create Gradio ChatInterface
demo = gr.ChatInterface(
    fn=chat,
    title="🎓 CCISD TEKS Educational Assistant",
    description="""
**Powered by Mistral-7B fine-tuned on 4,224 enhanced CCISD TEKS examples**

Ask questions about:
- 📚 TEKS standard definitions and explanations
- 👨‍🏫 Teaching strategies and lesson planning
- 🌎 Real-world applications
- ✅ Assessment and quiz generation
- 🎯 Learning objectives and prerequisites
- 💡 Common student misconceptions

*Trained on Texas TEKS standards with real-world educational data.*
""",
    examples=[
        "What is TEKS standard MATH.A1.3.B?",
        "How should I teach the scientific method to 7th grade students?",
        "What are real-world applications of linear equations?",
        "Create a quiz question to assess understanding of photosynthesis",
        "What are common misconceptions students have about fractions?",
        "Explain ELAR.9.1.A and suggest a lesson activity",
    ],
    cache_examples=False,  # Don't pre-compute example responses (too slow on CPU)
    theme="soft",
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me about Texas TEKS standards...", container=False, scale=7),
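    # NOTE: the three button kwargs below exist in Gradio 4.x but were removed
    # in Gradio 5 (an assumption about the installed version; drop them or pin
    # gradio<5 if ChatInterface raises a TypeError).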
    retry_btn="🔄 Retry",
    undo_btn="↩️ Undo",
    clear_btn="🗑️ Clear",
)
# Launch the app
if __name__ == "__main__":
    demo.launch()