File size: 2,598 Bytes
c1cfbf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00db099
c1cfbf2
 
 
 
 
35b16da
c1cfbf2
35b16da
c1cfbf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00db099
c1cfbf2
 
 
 
 
 
00db099
c1cfbf2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn as nn
from janome.tokenizer import Tokenizer
import argparse

# =====================
# Settings
# =====================

# Token-id sequences are padded or truncated to exactly this many entries
# in encode().
MAX_LEN = 20
# Width of each embedding vector and of the classifier head's input layer.
EMBED_SIZE = 64
# Checkpoint file read by torch.load() when this module is loaded.
MODEL_PATH = "japanese_sentiment_model.pth"

# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================
# Tokenizer
# =====================

# Single janome Tokenizer instance, created at module load and reused by
# every tokenize() call.
tokenizer = Tokenizer()


def tokenize(text):
    """Return the surface form of every token the module tokenizer finds in *text*."""
    surfaces = []
    for morpheme in tokenizer.tokenize(text):
        surfaces.append(morpheme.surface)
    return surfaces


# =====================
# Model
# =====================

class SentimentModel(nn.Module):
    """Mean-pooled embedding classifier producing one sigmoid score per example.

    Token ids are embedded, averaged over dimension 1, and passed through a
    two-layer feed-forward head that ends in a sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them
        # stable so existing checkpoints continue to load.
        self.embedding = nn.Embedding(vocab_size, EMBED_SIZE)
        self.fc = nn.Sequential(
            nn.Linear(EMBED_SIZE, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids. Dimension 1 is averaged away
                (assumed layout is (batch, seq_len) -- confirm with callers).

        Returns:
            Tensor of sigmoid outputs with the trailing size-1 feature
            dimension removed, i.e. one value per example.
        """
        x = self.embedding(x)
        x = x.mean(dim=1)
        x = self.fc(x)
        # Drop only the trailing feature dimension. A bare squeeze() would
        # also remove the batch dimension when the batch holds one example,
        # making the output rank depend on the batch size.
        return x.squeeze(-1)


# =====================
# Load model + vocab
# =====================

# Runs at module load (outside the __main__ guard): read the checkpoint,
# rebuild the model sized to the saved vocabulary, and restore its weights.
# The checkpoint must provide the "vocab" and "model_state_dict" keys.
# NOTE(review): torch.load relies on pickle unless weights_only loading is in
# effect -- only load checkpoint files from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device)
vocab = checkpoint["vocab"]
model = SentimentModel(len(vocab)).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
# Put the module in evaluation mode; this script only runs inference.
model.eval()

print("Model loaded successfully.")


def encode(text):
    """Convert *text* into a list of exactly MAX_LEN vocabulary ids.

    Tokens missing from the vocabulary map to the "<UNK>" id. Short
    sequences are right-padded with the "<PAD>" id; long ones are cut off
    after MAX_LEN entries.
    """
    ids = []
    for token in tokenize(text):
        ids.append(vocab.get(token, vocab["<UNK>"]))

    shortfall = MAX_LEN - len(ids)
    if shortfall > 0:
        ids.extend([vocab["<PAD>"]] * shortfall)
        return ids
    return ids[:MAX_LEN]


def predict_sentiment(text):
    """Classify *text* and return the label "Positive" or "Negative".

    The text is encoded into a single-example batch, scored by the loaded
    model without gradient tracking, and labelled "Positive" only when the
    score is strictly greater than 0.5.
    """
    batch = torch.tensor([encode(text)], dtype=torch.long).to(device)
    with torch.no_grad():
        score = model(batch).item()
    return "Positive" if score > 0.5 else "Negative"


if __name__ == "__main__":
    # CLI entry point: classify text given on the command line, or prompt
    # repeatedly in interactive mode. With neither, show the usage help.
    parser = argparse.ArgumentParser(
        description="Japanese sentiment prediction CLI using a saved PyTorch model."
    )
    parser.add_argument(
        "text",
        nargs="*",
        help="Text to predict. If omitted, use --interactive.",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Interactive mode. Type text repeatedly (type 'exit' to quit).",
    )

    args = parser.parse_args()

    if args.text:
        # Positional words are re-joined with single spaces. The label must
        # be printed: predict_sentiment() only returns it.
        print(predict_sentiment(" ".join(args.text)))
    elif args.interactive:
        while True:
            try:
                text = input("text> ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C (or exhausted piped stdin) ends the
                # session cleanly instead of raising a traceback.
                print()
                break
            if text.lower() in {"exit", "quit"}:
                break
            if text:
                print(predict_sentiment(text))
    else:
        parser.print_help()