# Source header (scraped repository metadata, kept as comments so the file parses):
# afehrjtmhdEGS / model_utils.py
# LR36's picture
# Update model_utils.py
# c54b282 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Curated canned answers for common ABA (Applied Behavior Analysis) topics.
# Keys are lowercase phrases matched by substring against the user's question
# (see generate_response); each value holds a one-line "definition", concrete
# "examples", and "key_points" — all displayable without invoking the model.
# NOTE(review): an earlier comment said this lives "inside app.py for
# simplicity", but the header names this file model_utils.py — confirm which
# module is the single source of truth for this data.
PREDETERMINED_ANSWERS = {
    "negative reinforcement": {
        "definition": "When a behavior increases because it removes/prevents an aversive stimulus.",
        "examples": [
            "Child stops whining when seatbelt clicks (removal of annoying sound reinforces buckling)",
            "Student completes work to avoid teacher reprimand"
        ],
        "key_points": [
            "NOT punishment (which decreases behavior)",
            "Two types: escape (stop existing stimulus) and avoidance (prevent stimulus)",
            "Common in escape-maintained behaviors"
        ]
    },
    "positive reinforcement": {
        "definition": "When a behavior increases because it produces a rewarding consequence.",
        "examples": [
            "Child says 'please' and gets a sticker (behavior increases)",
            "Employee meets deadline and receives bonus"
        ],
        "key_points": [
            "Most effective when immediate and contingent",
            "Can be tangible (toys) or social (praise)",
            "Should be individualized to the learner"
        ]
    },
    "aba": {
        "definition": "Applied Behavior Analysis - scientific approach using learning principles to improve behaviors.",
        "examples": [
            "Teaching communication skills using picture exchange",
            "Reducing self-injury through functional assessment"
        ],
        "key_points": [
            "Data-driven decision making",
            "Breaks skills into teachable steps",
            "Gold-standard for autism treatment",
            "Focuses on socially significant behaviors"
        ]
    },
    "differential reinforcement": {
        "definition": "Reinforcing specific behaviors while withholding reinforcement for others.",
        "examples": [
            "Rewarding quiet hands while ignoring flapping (DRI)",
            "Providing attention for polite requests but not whining (DRA)"
        ],
        "key_points": [
            "DRA: Alternative behavior",
            "DRO: Other behavior (absence of target)",
            "DRI: Incompatible behavior",
            "Requires consistency to be effective"
        ]
    },
    "functional behavior assessment": {
        "definition": "Process to identify the purpose/function of a behavior.",
        "examples": [
            "ABC data collection for aggression",
            "Interviews and scatter plots for elopement"
        ],
        "key_points": [
            "Identifies antecedents and consequences",
            "Reveals maintaining variables",
            "Foundation for Behavior Intervention Plans"
        ]
    }
}
def load_model():
    """Load the flan-t5-small tokenizer and seq2seq model, failing soft.

    Returns:
        (tokenizer, model) on success, or (None, None) on any failure
        (no network, missing weights, OOM, ...) so callers can fall back
        to the predetermined answers.
    """
    try:
        model_name = "google/flan-t5-small"  # small checkpoint keeps latency low
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # BUG FIX: fp16 was requested unconditionally, but half precision is
        # only reliable on GPU — many CPU ops lack fp16 kernels and either
        # raise or silently degrade. Pick the dtype from the actual device.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            device_map="auto",  # NOTE(review): requires the `accelerate` package — confirm installed
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        return tokenizer, model
    except Exception as e:
        # Broad catch is deliberate: any load failure degrades the app to
        # canned answers instead of crashing at startup.
        print(f"Model failed to load: {e}")
        return None, None
def generate_response(question, tokenizer, model):
# Check predetermined answers first (fast path)
question_lower = question.lower().strip()
for key in PREDETERMINED_ANSWERS:
if key in question_lower:
return format_answer(key) # Your formatting function
# Only use model if loaded
if tokenizer and model:
try:
inputs = tokenizer(
f"Answer this ABA question: {question}",
return_tensors="pt",
truncation=True,
max_length=512
).to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=150, # Shorter = faster
temperature=0.7,
do_sample=True
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
except Exception as e:
print(f"Generation failed: {e}")
return "I can't generate an answer right now. Try asking about: " + ", ".join(PREDETERMINED_ANSWERS.keys())
# NOTE(review): `demo` is never defined in this file, so this line raises
# NameError at import time. Presumably the Gradio Interface/Blocks object is
# constructed in app.py — confirm, and either remove this call or build the
# interface here before launching.
demo.launch()