pulse_core_1 / inference.py
Vu Anh
Update technical report and README with latest SVC model results
0b1c1cf
raw
history blame
8.81 kB
#!/usr/bin/env python3
"""
Inference script for Pulse Core 1 - Vietnamese Aspect Sentiment Analysis.
Loads trained sentiment models from local files and performs predictions.
"""
import argparse
import joblib
import os
import glob
def find_local_models():
    """Find the available local sentiment model files.

    Two locations are searched, relative to the current working directory:

    * ``exported``: files named ``uts2017_sentiment_*.joblib`` in the
      directory itself. When several match, the most recently modified one
      is selected.
    * ``runs``: ``runs/*/models/UTS2017_Bank_AspectSentiment_*.joblib``.
      Paths containing ``'SVC'`` are preferred. Within the preferred group
      (or among all matches if there is no SVC model), the most recently
      modified file is selected.

    Returns:
        ``{'exported': {...}, 'runs': {...}}``. Each inner dict maps
        ``'uts2017_sentiment'`` to the selected path, or is empty when
        nothing matched.
    """
    models = {
        'exported': {},
        'runs': {},
    }

    # Exported models in the project root. os.listdir() returns entries in
    # arbitrary order, so choose the newest file explicitly instead of
    # whichever match happens to be listed last.
    exported = [
        name for name in os.listdir('.')
        if name.startswith('uts2017_sentiment_')
        and name.endswith('.joblib')
        and os.path.isfile(name)
    ]
    if exported:
        models['exported']['uts2017_sentiment'] = max(exported, key=os.path.getmtime)

    # Models saved by training runs, most recently modified first.
    sentiment_runs = sorted(
        glob.glob('runs/*/models/UTS2017_Bank_AspectSentiment_*.joblib'),
        key=os.path.getmtime,
        reverse=True,
    )
    if sentiment_runs:
        # Prefer SVC models; otherwise fall back to the newest model of any type.
        svc_models = [path for path in sentiment_runs if 'SVC' in path]
        models['runs']['uts2017_sentiment'] = (svc_models or sentiment_runs)[0]

    return models
def load_model(model_path):
    """Deserialize a fitted model with joblib and report how many classes it has.

    Security note: ``joblib.load`` unpickles the file, which can execute
    arbitrary code. Only pass model files from trusted sources.

    Args:
        model_path: Path to a ``.joblib`` model file.

    Returns:
        The loaded object, or ``None`` when loading fails or the object has
        no ``classes_`` attribute. The error message is printed.
    """
    try:
        print(f"Loading model from: {model_path}")
        loaded = joblib.load(model_path)
        class_count = len(loaded.classes_)
        print(f"Model loaded successfully. Classes: {class_count}")
    except Exception as exc:
        # Broad catch on purpose: this is the CLI boundary, and any failure
        # here should become a printed message rather than a traceback.
        print(f"Error loading model: {exc}")
        loaded = None
    return loaded
def predict_text(model, text, top_k=3):
    """Classify one text and return its highest-probability labels.

    Args:
        model: Fitted classifier exposing ``predict_proba`` and ``classes_``.
        text: Input string, passed to the model unchanged.
        top_k: Maximum number of ranked labels to return (default 3).
            Values below 1 are treated as 1, because the top label is
            always needed.

    Returns:
        ``(prediction, confidence, top_predictions)``. ``top_predictions`` is
        a list of ``(label, probability)`` pairs in descending probability
        order; ``prediction`` and ``confidence`` are its first pair. On any
        error the message is printed and ``(None, 0, [])`` is returned.
    """
    try:
        probabilities = model.predict_proba([text])[0]
        # argsort is ascending, so reverse it and take the first k indices.
        # Slicing the reversed order (rather than argsort()[-k:]) prevents
        # k == 0 from turning into [-0:], which would select every class.
        ranked = probabilities.argsort()[::-1][:max(int(top_k), 1)]
        top_predictions = [(model.classes_[idx], probabilities[idx]) for idx in ranked]
        prediction, confidence = top_predictions[0]
        return prediction, confidence, top_predictions
    except Exception as e:
        print(f"Error making prediction: {e}")
        return None, 0, []
def interactive_mode(model, dataset_name):
    """Run a read-predict-print loop over text typed by the user.

    Each non-empty line read from stdin is classified with ``predict_text``.
    The top label, its confidence and the ranked alternatives are printed.
    The loop ends on ``quit``/``exit``/``q`` (case-insensitive), on Ctrl-C,
    or when stdin is exhausted (Ctrl-D or end of piped input).

    Args:
        model: Fitted classifier accepted by ``predict_text``.
        dataset_name: Not used by this function; kept for interface
            compatibility.
    """
    print(f"\n{'='*60}")
    print("INTERACTIVE MODE - VIETNAMESE BANKING ASPECT SENTIMENT ANALYSIS")
    print(f"{'='*60}")
    print("Enter Vietnamese banking text to analyze aspect and sentiment (type 'quit' to exit):")
    while True:
        try:
            user_input = input("\nText: ").strip()
            if user_input.lower() in ['quit', 'exit', 'q']:
                break
            if not user_input:
                continue
            prediction, confidence, top_predictions = predict_text(model, user_input)
            if prediction:
                print(f"Predicted category: {prediction}")
                print(f"Confidence: {confidence:.3f}")
                # Report the real count; models with fewer than 3 classes
                # return fewer alternatives.
                print(f"Top {len(top_predictions)} predictions:")
                for i, (category, prob) in enumerate(top_predictions, 1):
                    print(f" {i}. {category}: {prob:.3f}")
        except (KeyboardInterrupt, EOFError):
            # EOFError must end the loop too. Once stdin is exhausted, every
            # later input() call raises again, so handling it in the generic
            # branch below would loop forever.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")
def test_examples(model, dataset_name):
    """Run the model over a fixed set of Vietnamese banking sentences.

    For each sentence, prints the predicted label and its confidence. When
    the confidence is below 0.7, also prints up to three ranked alternatives.
    Sentences for which ``predict_text`` yields no prediction are skipped.

    Args:
        model: Fitted classifier accepted by ``predict_text``.
        dataset_name: Not used by this function; kept for interface
            compatibility.
    """
    samples = (
        "Tôi muốn mở tài khoản tiết kiệm mới",
        "Lãi suất vay mua nhà hiện tại quá cao",
        "Làm thế nào để đăng ký internet banking?",
        "Chi phí chuyển tiền ra nước ngoài rất đắt",
        "Ngân hàng ACB có uy tín không?",
        "Tôi cần hỗ trợ về dịch vụ ngân hàng",
        "Thẻ tín dụng bị khóa không rõ lý do",
        "Dịch vụ chăm sóc khách hàng rất tệ",
        "Khuyến mãi tháng này rất hấp dẫn",
        "Bảo mật tài khoản có được đảm bảo không?",
    )
    banner = "=" * 60
    print("\n" + banner)
    print("TESTING VIETNAMESE BANKING ASPECT SENTIMENT ANALYSIS")
    print(banner)

    for sentence in samples:
        label, score, ranked = predict_text(model, sentence)
        if not label:
            continue
        print(f"\nText: {sentence}")
        print(f"Prediction: {label}")
        print(f"Confidence: {score:.3f}")
        # Only uncertain predictions get their runner-up labels listed.
        if score < 0.7:
            print("Alternative predictions:")
            for rank, (category, prob) in enumerate(ranked[:3], start=1):
                print(f" {rank}. {category}: {prob:.3f}")
        print("-" * 60)
def list_available_models():
    """Print every discovered sentiment model with its file size in MB.

    Discovery is delegated to ``find_local_models``. Exported models are
    listed first, then models from training runs. If neither source has a
    model, a hint on how to train and export one is printed instead.
    """
    found = find_local_models()
    exported, runs = found['exported'], found['runs']

    print("Available Vietnamese Aspect Sentiment Models:")
    print("=" * 50)

    sections = (
        ("\nExported Models (Project Root):", exported),
        ("\nRuns Models (Training Directory):", runs),
    )
    for header, entries in sections:
        if not entries:
            continue
        print(header)
        for model_type, path in entries.items():
            size_mb = os.path.getsize(path) / (1024 * 1024)
            print(f" {model_type}: {path} ({size_mb:.1f}MB)")

    if not (exported or runs):
        print("No local sentiment models found!")
        print("Train a model first using: python train.py --export-model")
def main():
    """Command-line entry point for Pulse Core 1 inference.

    Modes, checked in this order:

    * ``--list-models``: print the discovered models and exit.
    * ``--text TEXT``: classify a single text.
    * ``--test-examples``: run the built-in banking examples.
    * no mode flag: run the built-in examples, then offer interactive mode.

    The model is taken from ``--model-path`` when given. Otherwise the model
    that ``find_local_models`` selects for ``--source`` is used.
    """
    parser = argparse.ArgumentParser(
        description="Inference with local Pulse Core 1 Vietnamese aspect sentiment models"
    )
    parser.add_argument(
        "--model-path",
        type=str,
        help="Path to specific sentiment model file"
    )
    parser.add_argument(
        "--text",
        type=str,
        help="Vietnamese banking text to analyze (if not provided, runs the built-in examples and offers interactive mode)"
    )
    parser.add_argument(
        "--test-examples",
        action="store_true",
        help="Test with predefined banking examples"
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List all available local sentiment models"
    )
    parser.add_argument(
        "--source",
        type=str,
        choices=["exported", "runs"],
        default="runs",
        help="Model source: exported files or runs directory (default: runs)"
    )
    args = parser.parse_args()

    # List models and exit
    if args.list_models:
        list_available_models()
        return

    dataset_name = "uts2017_sentiment"

    # Resolve which model file to load.
    if args.model_path:
        model_path = args.model_path
    else:
        model_path = find_local_models()[args.source].get('uts2017_sentiment')
        if model_path is None:
            print("No sentiment models found!")
            list_available_models()
            return
        print("Auto-selected UTS2017 sentiment model")

    # isfile rather than exists: a directory given via --model-path cannot be
    # loaded as a model.
    if not os.path.isfile(model_path):
        print(f"Model file not found: {model_path}")
        list_available_models()
        return

    model = load_model(model_path)
    # load_model signals failure with None. Compare against the sentinel
    # instead of relying on the truthiness of an arbitrary loaded object.
    if model is None:
        return

    if args.text:
        # Single prediction
        prediction, confidence, top_predictions = predict_text(model, args.text)
        if prediction:
            print(f"\nText: {args.text}")
            print(f"Prediction: {prediction}")
            print(f"Confidence: {confidence:.3f}")
            print(f"Top {len(top_predictions)} predictions:")
            for i, (category, prob) in enumerate(top_predictions, 1):
                print(f" {i}. {category}: {prob:.3f}")
    elif args.test_examples:
        test_examples(model, dataset_name)
    else:
        # No explicit mode: show the built-in examples, then offer the REPL.
        print(f"Loaded sentiment model: {os.path.basename(model_path)}")
        test_examples(model, dataset_name)
        try:
            response = input("\nEnter interactive mode? (y/n): ").strip().lower()
            if response in ['y', 'yes']:
                interactive_mode(model, dataset_name)
        except (KeyboardInterrupt, EOFError):
            # EOFError occurs when stdin is closed or piped (non-interactive
            # runs); exit cleanly instead of with a traceback.
            print("\nExiting...")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()