File size: 2,241 Bytes
6cd0af6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | """
Regulatory Capacity Classifier - Usage Example
This script demonstrates how to use the trained model for inference.
"""
import torch
from transformers import BertTokenizer, BertForSequenceClassification
# --- Configuration ----------------------------------------------------------
# Location handed to `from_pretrained` for both the tokenizer and the model.
MODEL_PATH = "./final_model"

# Label names by output position: LABELS[i] is the name reported for logit i.
# NOTE(review): this order must match the label order used at training time —
# confirm against the training script.
LABELS = [
    "Cog-Evaluate",
    "Cog-Explain",
    "Cog-Generate",
    "Cog-Reason",
    "Meta-Monitor",
    "Meta-Orient",
    "Meta-Plan",
    "Socio-Coordinate",
    "Socio-Encourage",
    "Socio-Feedback",
    "TE-Act",
    "TE-Report",
]
def load_model(model_path=None):
    """Load the trained model and tokenizer.

    Args:
        model_path: Path (or identifier) accepted by ``from_pretrained``,
            used for both the tokenizer and the model. Defaults to
            ``MODEL_PATH`` when omitted.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been switched to
        evaluation mode (``model.eval()``) so layers such as dropout behave
        deterministically during inference.
    """
    # Resolve the default lazily so the module constant is read at call time.
    if model_path is None:
        model_path = MODEL_PATH
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(model_path)
    model.eval()
    return tokenizer, model
def predict(text, tokenizer, model, threshold=0.5, *, max_length=128):
    """Predict regulatory capacity labels for a given text.

    Args:
        text: The string to classify. Only the first row of the model output
            is read back, so pass a single string per call.
        tokenizer: Tokenizer as returned by ``load_model``.
        model: Sequence-classification model as returned by ``load_model``.
        threshold: Per-label probability cut-off; a label is predicted when
            its sigmoid probability is strictly greater than this value.
        max_length: Maximum number of tokens; longer inputs are truncated.

    Returns:
        ``(predicted_labels, confidences)``: ``predicted_labels`` lists the
        names from ``LABELS`` whose probability exceeds ``threshold`` (in
        ``LABELS`` order); ``confidences`` maps every label name to its
        probability as a Python float.

    Raises:
        ValueError: If the model emits a different number of logits than
            there are names in ``LABELS``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=True,
    )
    # Put the input tensors on the model's device: a no-op for a CPU model,
    # but prevents a device-mismatch error if the caller moved it to a GPU.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Independent sigmoid per label, i.e. multi-label scoring.
    # NOTE(review): assumes the model was trained with a per-label
    # (sigmoid/BCE) objective — confirm against the training setup.
    # Row 0 is the first (for a single string, the only) input.
    probs = torch.sigmoid(logits)[0]
    if probs.shape[-1] != len(LABELS):
        # Otherwise extra logits would be silently dropped and missing ones
        # would surface as an IndexError far from the real cause.
        raise ValueError(
            f"Model returned {probs.shape[-1]} logits but LABELS has "
            f"{len(LABELS)} names; the label list does not match the model."
        )

    # Threshold on the tensor itself so the comparison happens in the
    # model's dtype, exactly as before; convert to Python values once.
    hits = (probs > threshold).tolist()
    scores = probs.tolist()

    predicted_labels = [label for label, hit in zip(LABELS, hits) if hit]
    confidences = dict(zip(LABELS, scores))
    return predicted_labels, confidences
def main():
    """Load the classifier and print predictions for a few sample sentences."""
    print("Loading model...")
    tokenizer, model = load_model()

    # Example texts
    test_texts = [
        "I think we should evaluate our approach before moving forward.",
        "Let's coordinate our tasks and divide the work equally.",
        "Good job everyone! We're making great progress.",
        "Can you explain why you chose that option?",
        "I'll write down our conclusions in the report.",
    ]

    print("\n" + "=" * 60)
    print("Regulatory Capacity Predictions")
    print("=" * 60)

    for text in test_texts:
        labels, confidences = predict(text, tokenizer, model)

        # Three highest-probability labels, best first, formatted to 3 dp.
        # Built outside the f-string: the previous version escaped the dict
        # braces with `{{ }}`, so `{v:.3f}` was evaluated in main's scope
        # where `v` does not exist (NameError).
        ranked = sorted(confidences.items(), key=lambda item: -item[1])[:3]
        top_confidences = {name: f"{score:.3f}" for name, score in ranked}

        print(f"\nText: {text}")
        print(f"Predicted Labels: {labels}")
        print(f"Top Confidences: {top_confidences}")


if __name__ == "__main__":
    main()
|