"""
Regulatory Capacity Classifier - Usage Example

This script demonstrates how to use the trained model for inference.
"""

import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Directory containing the fine-tuned model weights and tokenizer files.
MODEL_PATH = "./final_model"
# Multi-label taxonomy; order must match the model's output head
# (one sigmoid logit per label, in this exact position).
LABELS = ['Cog-Evaluate', 'Cog-Explain', 'Cog-Generate', 'Cog-Reason', 'Meta-Monitor', 'Meta-Orient', 'Meta-Plan', 'Socio-Coordinate', 'Socio-Encourage', 'Socio-Feedback', 'TE-Act', 'TE-Report']
| |
|
def load_model():
    """Load the fine-tuned tokenizer and model from MODEL_PATH.

    Returns:
        A (tokenizer, model) pair, with the model switched to eval mode
        so dropout is disabled during inference.
    """
    bert_tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
    bert_model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
    bert_model.eval()
    return bert_tokenizer, bert_model
| |
|
def predict(text, tokenizer, model, threshold=0.5, labels=None):
    """Predict regulatory capacity labels for a given text.

    Args:
        text: Input utterance to classify.
        tokenizer: Callable returning PyTorch tensors for the model
            (e.g. a HuggingFace tokenizer).
        model: Sequence-classification model whose output has a ``logits``
            attribute with one logit per label.
        threshold: Sigmoid probability cutoff for assigning a label.
        labels: Optional label names, one per output logit. Defaults to the
            module-level LABELS (backward compatible with existing callers).

    Returns:
        A tuple ``(predicted_labels, confidences)`` where ``predicted_labels``
        is the list of label names whose probability exceeds ``threshold``
        and ``confidences`` maps every label name to its sigmoid probability.
    """
    label_names = LABELS if labels is None else labels

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding=True,
    )

    # Multi-label setup: an independent sigmoid per class (not softmax),
    # so several labels can fire on the same utterance.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.sigmoid(outputs.logits)

    row = probs[0]  # single input text -> first (only) row of the batch
    predicted_labels = [name for name, p in zip(label_names, row) if p > threshold]
    confidences = {name: float(p) for name, p in zip(label_names, row)}

    return predicted_labels, confidences
| |
|
def main():
    """Run the trained classifier on sample utterances and print results."""
    print("Loading model...")
    tokenizer, model = load_model()

    test_texts = [
        "I think we should evaluate our approach before moving forward.",
        "Let's coordinate our tasks and divide the work equally.",
        "Good job everyone! We're making great progress.",
        "Can you explain why you chose that option?",
        "I'll write down our conclusions in the report."
    ]

    print("\n" + "=" * 60)
    print("Regulatory Capacity Predictions")
    print("=" * 60)

    for text in test_texts:
        labels, confidences = predict(text, tokenizer, model)
        print(f"\nText: {text}")
        print(f"Predicted Labels: {labels}")
        # BUG FIX: the original f-string doubled the braces, so the dict
        # comprehension was printed as literal text while the inner
        # `{v:.3f}` referenced an undefined `v` and raised a NameError.
        # Compute the top-3 confidences explicitly instead.
        top3 = sorted(confidences.items(), key=lambda kv: -kv[1])[:3]
        print("Top Confidences: " + ", ".join(f"{k}: {v:.3f}" for k, v in top3))
| |
|
# Entry point: run the demo only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|