| |
| """ |
| Inference script for Pulse Core 1 - Vietnamese Aspect Sentiment Analysis. |
| Loads trained sentiment models from local files and performs predictions. |
| """ |
|
|
| import argparse |
| import joblib |
| import os |
| import glob |
|
|
|
|
def find_local_models():
    """Locate sentiment model files on the local filesystem.

    Two locations are searched, both relative to the current working
    directory:

    * ``exported`` -- files named ``uts2017_sentiment_*.joblib`` directly in
      the working directory.
    * ``runs`` -- files matching
      ``runs/*/models/UTS2017_Bank_AspectSentiment_*.joblib``.

    When several candidates exist, the most recently modified one is
    selected (ties are broken by path so the result is deterministic).
    Among run models, paths containing ``SVC`` take precedence over all
    others.

    Returns:
        dict: ``{'exported': {...}, 'runs': {...}}``. Each inner dict maps
        ``'uts2017_sentiment'`` to the selected path, or is empty when no
        matching file was found.
    """
    models = {
        'exported': {},
        'runs': {},
    }

    # Exported models live next to the script. Previously the last matching
    # entry returned by os.listdir() won, which is an arbitrary order.
    exported = [
        name for name in os.listdir('.')
        if name.startswith('uts2017_sentiment_') and name.endswith('.joblib')
    ]
    if exported:
        models['exported']['uts2017_sentiment'] = _newest_first(exported)[0]

    sentiment_runs = _newest_first(
        glob.glob('runs/*/models/UTS2017_Bank_AspectSentiment_*.joblib')
    )
    if sentiment_runs:
        # Prefer SVC models; fall back to the newest model of any kind.
        svc_models = [path for path in sentiment_runs if 'SVC' in path]
        chosen = svc_models[0] if svc_models else sentiment_runs[0]
        models['runs']['uts2017_sentiment'] = chosen

    return models


def _model_mtime(path):
    """Return the modification time of *path*, or 0.0 if it cannot be read."""
    try:
        return os.path.getmtime(path)
    except OSError:
        # E.g. a broken symlink or a file removed after it was listed.
        return 0.0


def _newest_first(paths):
    """Return *paths* ordered from most to least recently modified."""
    return sorted(paths, key=lambda p: (_model_mtime(p), p), reverse=True)
|
|
|
|
def load_model(model_path):
    """Deserialize a model with joblib, reporting progress on stdout.

    Args:
        model_path: Location of the serialized model file.

    Returns:
        The loaded model object, or ``None`` if anything raised while
        loading it or while reading its ``classes_`` attribute for the
        status message.
    """
    try:
        print(f"Loading model from: {model_path}")
        loaded = joblib.load(model_path)
        class_count = len(loaded.classes_)
        print(f"Model loaded successfully. Classes: {class_count}")
    except Exception as err:
        print(f"Error loading model: {err}")
        return None
    return loaded
|
|
|
|
def predict_text(model, text):
    """Classify one text and rank the most probable labels.

    Args:
        model: Object exposing ``predict_proba`` and ``classes_``.
        text: Text to classify; it is passed to the model as a
            one-element list.

    Returns:
        tuple: ``(prediction, confidence, top_predictions)``.
        ``top_predictions`` holds up to three ``(label, probability)``
        pairs ordered from most to least probable; ``prediction`` and
        ``confidence`` are the contents of its first entry. If anything
        raises, the error is printed and ``(None, 0, [])`` is returned.
    """
    try:
        scores = model.predict_proba([text])[0]

        # Indices in descending order of probability, cut to the best three.
        best_first = scores.argsort()[::-1][:3]
        ranked = [(model.classes_[pos], scores[pos]) for pos in best_first]

        best_label, best_score = ranked[0]
        return best_label, best_score, ranked
    except Exception as err:
        print(f"Error making prediction: {err}")
        return None, 0, []
|
|
|
|
def interactive_mode(model, dataset_name):
    """Run a read-predict-print loop on standard input.

    Each non-empty line is passed to ``predict_text`` and the prediction,
    its confidence and the ranked alternatives are printed. The loop ends
    when the user types ``quit``, ``exit`` or ``q`` (case-insensitive),
    presses Ctrl-C, or when standard input reaches end-of-file.

    Args:
        model: Model object forwarded to ``predict_text``.
        dataset_name: Accepted for call-site compatibility; not used here.
    """
    print(f"\n{'='*60}")
    print("INTERACTIVE MODE - VIETNAMESE BANKING ASPECT SENTIMENT ANALYSIS")
    print(f"{'='*60}")
    print("Enter Vietnamese banking text to analyze aspect and sentiment (type 'quit' to exit):")

    while True:
        try:
            user_input = input("\nText: ").strip()

            if user_input.lower() in ('quit', 'exit', 'q'):
                break

            if not user_input:
                continue

            prediction, confidence, top_predictions = predict_text(model, user_input)

            if prediction:
                print(f"Predicted category: {prediction}")
                print(f"Confidence: {confidence:.3f}")
                print("Top 3 predictions:")
                for i, (category, prob) in enumerate(top_predictions, 1):
                    print(f" {i}. {category}: {prob:.3f}")

        except (KeyboardInterrupt, EOFError):
            # EOFError is raised by input() once stdin is closed or
            # exhausted (piped input, Ctrl-D). It must end the loop here:
            # the generic handler below would otherwise catch it on every
            # iteration and spin forever.
            print("\nExiting...")
            break
        except Exception as e:
            # Keep the session alive on unexpected per-input failures.
            print(f"Error: {e}")
|
|
|
|
def test_examples(model, dataset_name):
    """Run the model over a fixed set of Vietnamese banking sentences.

    For every sentence that yields a prediction, the text, predicted label
    and confidence are printed. When the confidence is below 0.7 the top
    ranked alternatives are listed as well. A dashed separator follows each
    reported sentence.

    Args:
        model: Model object forwarded to ``predict_text``.
        dataset_name: Accepted for call-site compatibility; not used here.
    """
    samples = [
        "Tôi muốn mở tài khoản tiết kiệm mới",
        "Lãi suất vay mua nhà hiện tại quá cao",
        "Làm thế nào để đăng ký internet banking?",
        "Chi phí chuyển tiền ra nước ngoài rất đắt",
        "Ngân hàng ACB có uy tín không?",
        "Tôi cần hỗ trợ về dịch vụ ngân hàng",
        "Thẻ tín dụng bị khóa không rõ lý do",
        "Dịch vụ chăm sóc khách hàng rất tệ",
        "Khuyến mãi tháng này rất hấp dẫn",
        "Bảo mật tài khoản có được đảm bảo không?"
    ]

    banner = "=" * 60
    print("\n" + banner)
    print("TESTING VIETNAMESE BANKING ASPECT SENTIMENT ANALYSIS")
    print(banner)

    for sentence in samples:
        label, score, ranked = predict_text(model, sentence)

        # Nothing to report when the prediction failed.
        if not label:
            continue

        print(f"\nText: {sentence}")
        print(f"Prediction: {label}")
        print(f"Confidence: {score:.3f}")

        # Low confidence: show the competing labels as well.
        if score < 0.7:
            print("Alternative predictions:")
            for rank, (category, prob) in enumerate(ranked[:3], 1):
                print(f" {rank}. {category}: {prob:.3f}")
        print("-" * 60)
|
|
|
|
def list_available_models():
    """Print the models reported by ``find_local_models`` with their sizes.

    Exported models are listed first, then models from training runs; each
    entry shows its key, path and file size in megabytes. If neither group
    has entries, a hint on how to train and export a model is printed
    instead.
    """
    found = find_local_models()

    print("Available Vietnamese Aspect Sentiment Models:")
    print("=" * 50)

    # (key in the discovery result, heading printed above that group)
    sections = (
        ('exported', "\nExported Models (Project Root):"),
        ('runs', "\nRuns Models (Training Directory):"),
    )
    for source, heading in sections:
        entries = found[source]
        if not entries:
            continue
        print(heading)
        for model_type, path in entries.items():
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f" {model_type}: {path} ({size_mb:.1f}MB)")

    if not (found['exported'] or found['runs']):
        print("No local sentiment models found!")
        print("Train a model first using: python train.py --export-model")
|
|
|
|
def main():
    """Command-line entry point for local sentiment-model inference.

    Behavior depends on the parsed arguments:

    * ``--list-models``: print the discovered models and return.
    * ``--model-path``: use that file; otherwise the model discovered for
      ``--source`` (``runs`` by default) is selected automatically.
    * ``--text``: classify the given text and print the ranked result.
    * ``--test-examples``: run the predefined examples.
    * neither of the two: run the predefined examples, then offer to enter
      interactive mode.

    If no usable model file is found, or the model fails to load, the
    function prints a message and returns without predicting.
    """
    parser = argparse.ArgumentParser(
        description="Inference with local Pulse Core 1 Vietnamese aspect sentiment models"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        help="Path to specific sentiment model file"
    )
    parser.add_argument(
        "--text",
        type=str,
        help="Vietnamese banking text to analyze (if not provided, enters interactive mode)"
    )
    parser.add_argument(
        "--test-examples",
        action="store_true",
        help="Test with predefined banking examples"
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List all available local sentiment models"
    )
    parser.add_argument(
        "--source",
        type=str,
        choices=["exported", "runs"],
        default="runs",
        help="Model source: exported files or runs directory (default: runs)"
    )

    args = parser.parse_args()

    # Listing is a standalone action: no model is loaded.
    if args.list_models:
        list_available_models()
        return

    models = find_local_models()
    dataset_name = "uts2017_sentiment"

    if args.model_path:
        # An explicit path always wins over auto-discovery.
        model_path = args.model_path
    else:
        candidates = models[args.source]
        if 'uts2017_sentiment' not in candidates:
            print("No sentiment models found!")
            list_available_models()
            return
        model_path = candidates['uts2017_sentiment']
        print("Auto-selected UTS2017 sentiment model")

    if not model_path or not os.path.exists(model_path):
        print(f"Model file not found: {model_path}")
        list_available_models()
        return

    model = load_model(model_path)
    # load_model signals failure with None; testing identity avoids
    # rejecting a loaded object that happens to be falsy (e.g. defines
    # __len__ and is empty).
    if model is None:
        return

    if args.text:
        prediction, confidence, top_predictions = predict_text(model, args.text)
        if prediction:
            print(f"\nText: {args.text}")
            print(f"Prediction: {prediction}")
            print(f"Confidence: {confidence:.3f}")
            print("Top 3 predictions:")
            for i, (category, prob) in enumerate(top_predictions, 1):
                print(f" {i}. {category}: {prob:.3f}")

    elif args.test_examples:
        test_examples(model, dataset_name)

    else:
        # Default: demo on the examples, then optionally go interactive.
        print(f"Loaded sentiment model: {os.path.basename(model_path)}")
        test_examples(model, dataset_name)

        try:
            response = input("\nEnter interactive mode? (y/n): ").strip().lower()
            if response in ('y', 'yes'):
                interactive_mode(model, dataset_name)
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin is closed or not interactive (piped run);
            # exit quietly instead of crashing with a traceback.
            print("\nExiting...")


if __name__ == "__main__":
    main()