|
|
|
|
|
|
|
|
""" |
|
|
predict_ret_next.py |
|
|
------------------- |
|
|
使い方: |
|
|
python predict_ret_next.py "業績予想の上方修正に関するお知らせ" |
|
|
|
|
|
オプション: |
|
|
  --model  パス (default: カレントディレクトリの ./model.joblib)
|
|
--embed Sentence-Transformers 名 (default: paraphrase-multilingual-MiniLM-L12-v2) |
|
|
""" |
|
|
|
|
|
import argparse, joblib, os |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
def load_model(model_path):
    """Load and return the pickled regressor stored at *model_path*.

    Parameters
    ----------
    model_path : str
        Path to a joblib-serialized model file.

    Raises
    ------
    FileNotFoundError
        If no file exists at *model_path*.
    """
    if os.path.exists(model_path):
        return joblib.load(model_path)
    raise FileNotFoundError(f"model not found: {model_path}")
|
|
|
|
|
def main():
    """CLI entry point: embed a disclosure title and print the predicted
    next-business-day return.

    Parses the title plus optional --model / --embed paths, loads the
    joblib regressor and the sentence-transformers encoder, then prints
    the regressor's prediction for the single title.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("title", help="開示タイトル(日本語 or 英語)")
    parser.add_argument("--model", default="model.joblib",
                        help="joblib file path (default: ./model.joblib)")
    parser.add_argument("--embed", default="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                        help="embedding model name or path")
    opts = parser.parse_args()

    print("▶ loading models …")
    regressor = load_model(opts.model)
    # Pick the device up front; embedder_gpu() falls back to CPU when
    # torch/CUDA is unavailable.
    device = "cuda" if embedder_gpu() else "cpu"
    embedder = SentenceTransformer(opts.embed, device=device)

    # encode() takes a list of sentences; we embed exactly one title.
    title_vec = embedder.encode([opts.title])
    prediction = regressor.predict(title_vec)[0]
    # NOTE(review): assumes the model predicts in percent units — confirm
    # against the training pipeline.
    print(f"\n予測翌営業日リターン: {prediction:.2f} %")
|
|
|
|
|
def embedder_gpu():
    """Return True when a CUDA GPU usable for the embedder is present.

    A device qualifies when torch reports CUDA as available and its compute
    capability, encoded as ``major * 10 + minor``, is at most 90 (i.e. up to
    and including capability 9.0).

    Returns
    -------
    bool
        False when torch is not installed, CUDA is unavailable, or the
        device's compute capability exceeds the cap.
    """
    try:
        # torch is an optional dependency; import lazily so the script
        # still runs CPU-only without it.
        import torch
        if torch.cuda.is_available():
            # Fix: the original unpacked into `min`, shadowing the builtin.
            major, minor = torch.cuda.get_device_capability()
            # NOTE(review): caps at compute capability 9.0 — presumably to
            # exclude untested newer architectures; confirm this is intended.
            return (major * 10 + minor) <= 90
    except ImportError:
        pass
    return False
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|