#!/usr/bin/env python # -*- coding: utf-8 -*- """ predict_ret_next.py ------------------- 使い方: python predict_ret_next.py "業績予想の上方修正に関するお知らせ" オプション: --model パス (default: model.joblib と同じフォルダ) --embed Sentence-Transformers 名 (default: paraphrase-multilingual-MiniLM-L12-v2) """ import argparse, joblib, os from sentence_transformers import SentenceTransformer def load_model(model_path): if not os.path.exists(model_path): raise FileNotFoundError(f"model not found: {model_path}") return joblib.load(model_path) def main(): ap = argparse.ArgumentParser() ap.add_argument("title", help="開示タイトル(日本語 or 英語)") ap.add_argument("--model", default="model.joblib", help="joblib file path (default: ./model.joblib)") ap.add_argument("--embed", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", help="embedding model name or path") args = ap.parse_args() print("▶ loading models …") reg = load_model(args.model) embedder = SentenceTransformer(args.embed, device="cuda" if embedder_gpu() else "cpu") vec = embedder.encode([args.title]) pred = reg.predict(vec)[0] print(f"\n予測翌営業日リターン: {pred:.2f} %") def embedder_gpu(): try: import torch if torch.cuda.is_available(): maj, min = torch.cuda.get_device_capability() return (maj * 10 + min) <= 90 # sm_120 以上は未対応 ⇒ CPU except ImportError: pass return False if __name__ == "__main__": main()