Japanese_sentiment / detect.py
Obb12's picture
Update detect.py
35b16da verified
import torch
import torch.nn as nn
from janome.tokenizer import Tokenizer
import argparse
# =====================
# Settings
# =====================
# Number of token ids fed to the model per text (shorter inputs are padded,
# longer ones truncated by encode()).
MAX_LEN = 20
# Width of each embedding vector in SentimentModel.
EMBED_SIZE = 64
# Checkpoint file holding both the vocabulary and the model weights.
MODEL_PATH = "japanese_sentiment_model.pth"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# =====================
# Tokenizer
# =====================
# Shared Janome tokenizer instance, built once at import time.
tokenizer = Tokenizer()


def tokenize(text):
    """Split ``text`` with Janome and return the tokens' surface strings.

    Args:
        text: The string to tokenize.

    Returns:
        A list containing the ``surface`` attribute of every token that
        ``tokenizer.tokenize`` yields, in order.
    """
    surfaces = []
    for morpheme in tokenizer.tokenize(text):
        surfaces.append(morpheme.surface)
    return surfaces
# =====================
# Model
# =====================
class SentimentModel(nn.Module):
    """Mean-pooled embedding classifier that emits a sigmoid score.

    Token ids are embedded, averaged over dimension 1 (every position is
    included; no padding mask is applied), and passed through a two-layer
    feed-forward head ending in a sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, EMBED_SIZE)
        # Attribute names and layer positions determine the state_dict keys
        # (embedding.*, fc.0.*, fc.2.*), so they are kept exactly as saved
        # checkpoints expect.
        head_layers = [
            nn.Linear(EMBED_SIZE, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the sigmoid score(s) for a tensor of token ids.

        Args:
            x: Integer tensor of token ids; assumed to be shaped
                (batch, sequence) -- TODO confirm against the training code.

        Returns:
            The head output with every size-1 dimension removed by
            ``squeeze()``; a single-row batch therefore yields a 0-d tensor.
        """
        pooled = self.embedding(x).mean(dim=1)
        return self.fc(pooled).squeeze()
# =====================
# Load model + vocab
# =====================
# Read the checkpoint at MODEL_PATH, mapping stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source; consider passing weights_only=True where the installed
# PyTorch version supports it -- TODO confirm this checkpoint loads that way.
checkpoint = torch.load(MODEL_PATH, map_location=device)
# Token -> id mapping stored in the checkpoint; its length sizes the embedding.
vocab = checkpoint["vocab"]
# Rebuild the network with a matching vocabulary size and move it to `device`.
model = SentimentModel(len(vocab)).to(device)
# Restore the saved weights into the freshly built network.
model.load_state_dict(checkpoint["model_state_dict"])
# Put the module in evaluation mode for inference.
model.eval()
# Runs at import time, so this message is emitted whenever the module loads.
print("Model loaded successfully.")
def encode(text):
    """Convert ``text`` into a list of exactly ``MAX_LEN`` vocabulary ids.

    Tokens missing from ``vocab`` map to the ``"<UNK>"`` id. Sequences
    shorter than ``MAX_LEN`` are right-padded with the ``"<PAD>"`` id;
    longer ones are cut to their first ``MAX_LEN`` ids.

    Args:
        text: The string to encode.

    Returns:
        A list of ``MAX_LEN`` integer ids.
    """
    ids = []
    for token in tokenize(text):
        ids.append(vocab.get(token, vocab["<UNK>"]))

    shortfall = MAX_LEN - len(ids)
    if shortfall > 0:
        ids.extend([vocab["<PAD>"]] * shortfall)
        return ids
    return ids[:MAX_LEN]
def predict_sentiment(text):
    """Classify ``text`` as ``"Positive"`` or ``"Negative"``.

    The text is encoded into a single-row batch, scored by the loaded model
    with gradient tracking disabled, and thresholded at 0.5.

    Args:
        text: The string to classify.

    Returns:
        ``"Positive"`` when the model score is strictly greater than 0.5,
        otherwise ``"Negative"``.
    """
    batch = torch.tensor([encode(text)], dtype=torch.long).to(device)
    with torch.no_grad():
        score = model(batch).item()
    return "Positive" if score > 0.5 else "Negative"
if __name__ == "__main__":
    # Command-line entry point. Positional words are joined with spaces and
    # classified once; with -i/--interactive, lines are read and classified
    # until 'exit'/'quit' (or end of input); with neither, the help is shown.
    parser = argparse.ArgumentParser(
        description="Japanese sentiment prediction CLI using a saved PyTorch model."
    )
    parser.add_argument(
        "text",
        nargs="*",
        help="Text to predict. If omitted, use --interactive.",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Interactive mode. Type text repeatedly (type 'exit' to quit).",
    )
    args = parser.parse_args()

    if args.text:
        # The prediction was previously computed and discarded; print it so
        # the CLI actually reports a result.
        print(predict_sentiment(" ".join(args.text)))
    elif args.interactive:
        while True:
            try:
                text = input("text> ").strip()
            except EOFError:
                # Closed stdin (Ctrl-D or an exhausted pipe) ends the session
                # the same way as typing 'exit' instead of raising.
                break
            if text.lower() in {"exit", "quit"}:
                break
            if text:
                print(predict_sentiment(text))
    else:
        parser.print_help()