jfinance-title2return-v1 / predict_ret_next.py
Migaku
initial model card & files
d161ef2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
predict_ret_next.py
-------------------
使い方:
python predict_ret_next.py "業績予想の上方修正に関するお知らせ"
オプション:
--model パス (default: model.joblib と同じフォルダ)
--embed Sentence-Transformers 名 (default: paraphrase-multilingual-MiniLM-L12-v2)
"""
import argparse, joblib, os
from sentence_transformers import SentenceTransformer
def load_model(model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"model not found: {model_path}")
return joblib.load(model_path)
def main():
ap = argparse.ArgumentParser()
ap.add_argument("title", help="開示タイトル(日本語 or 英語)")
ap.add_argument("--model", default="model.joblib",
help="joblib file path (default: ./model.joblib)")
ap.add_argument("--embed", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
help="embedding model name or path")
args = ap.parse_args()
print("▶ loading models …")
reg = load_model(args.model)
embedder = SentenceTransformer(args.embed,
device="cuda" if embedder_gpu() else "cpu")
vec = embedder.encode([args.title])
pred = reg.predict(vec)[0]
print(f"\n予測翌営業日リターン: {pred:.2f} %")
def embedder_gpu():
try:
import torch
if torch.cuda.is_available():
maj, min = torch.cuda.get_device_capability()
return (maj * 10 + min) <= 90 # sm_120 以上は未対応 ⇒ CPU
except ImportError:
pass
return False
if __name__ == "__main__":
main()