"""Command-line sentiment prediction for Japanese text.

Loads a saved PyTorch checkpoint (model weights plus vocabulary) at import
time, tokenizes input text with Janome, and prints "Positive" or "Negative"
for each piece of text.
"""

import torch
import torch.nn as nn
from janome.tokenizer import Tokenizer
import argparse

# =====================
# Settings
# =====================
MAX_LEN = 20  # every input is padded or truncated to this many token ids
EMBED_SIZE = 64  # embedding width; must match the saved checkpoint
MODEL_PATH = "japanese_sentiment_model.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================
# Tokenizer
# =====================
tokenizer = Tokenizer()


def tokenize(text):
    """Return the surface form of every token Janome finds in *text*."""
    return [token.surface for token in tokenizer.tokenize(text)]


# =====================
# Model
# =====================
class SentimentModel(nn.Module):
    """Mean-of-embeddings classifier ending in a sigmoid."""

    def __init__(self, vocab_size):
        """Build the embedding table and the classifier head.

        Args:
            vocab_size: Number of rows in the embedding table.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, EMBED_SIZE)
        self.fc = nn.Sequential(
            nn.Linear(EMBED_SIZE, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids; ``predict_sentiment`` passes a
                ``(1, MAX_LEN)`` tensor.

        Returns:
            The sigmoid output with all size-1 dimensions squeezed away, so
            a single-example batch yields a 0-d tensor.
        """
        x = self.embedding(x)
        x = x.mean(dim=1)  # average the embeddings along dim 1
        x = self.fc(x)
        return x.squeeze()


# =====================
# Load model + vocab
# =====================
# NOTE(security): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device)
vocab = checkpoint["vocab"]

model = SentimentModel(len(vocab)).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print("Model loaded successfully.")


def encode(text):
    """Convert *text* into a list of exactly ``MAX_LEN`` vocabulary ids.

    Tokens missing from the vocabulary map to the id stored under the ``""``
    key; short inputs are right-padded with that same id and long inputs
    are truncated.

    Raises:
        KeyError: If the vocabulary has no ``""`` entry.
    """
    # NOTE(review): the empty-string key doubles as both the unknown-token
    # and the padding id. This looks like special-token names (e.g. "<UNK>"
    # / "<PAD>") may have been stripped at some point -- confirm against the
    # vocabulary written by the training script.
    fallback_id = vocab[""]

    ids = [vocab.get(token, fallback_id) for token in tokenize(text)]
    ids = ids[:MAX_LEN]
    ids += [fallback_id] * (MAX_LEN - len(ids))
    return ids


def predict_sentiment(text):
    """Return ``"Positive"`` or ``"Negative"`` for *text*.

    The label is "Positive" when the model's output is strictly greater
    than 0.5, otherwise "Negative".
    """
    x = torch.tensor([encode(text)], dtype=torch.long).to(device)

    with torch.no_grad():
        output = model(x).item()

    return "Positive" if output > 0.5 else "Negative"


def main():
    """Parse command-line arguments and print a prediction for each input."""
    parser = argparse.ArgumentParser(
        description="Japanese sentiment prediction CLI using a saved PyTorch model."
    )
    parser.add_argument(
        "text",
        nargs="*",
        help="Text to predict. If omitted, use --interactive.",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Interactive mode. Type text repeatedly (type 'exit' to quit).",
    )
    args = parser.parse_args()

    if args.text:
        # The prediction must be printed; otherwise the CLI produces no
        # visible result.
        print(predict_sentiment(" ".join(args.text)))
    elif args.interactive:
        while True:
            try:
                text = input("text> ").strip()
            except EOFError:
                # End of input (Ctrl-D or exhausted pipe): leave quietly.
                break
            if text.lower() in {"exit", "quit"}:
                break
            if text:
                print(predict_sentiment(text))
    else:
        parser.print_help()


if __name__ == "__main__":
    main()