Spaces:
Running
Running
File size: 2,598 Bytes
c1cfbf2 00db099 c1cfbf2 35b16da c1cfbf2 35b16da c1cfbf2 00db099 c1cfbf2 00db099 c1cfbf2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 | import torch
import torch.nn as nn
from janome.tokenizer import Tokenizer
import argparse
# =====================
# Settings
# =====================
# Every encoded input is right-padded or truncated to this many token ids.
MAX_LEN: int = 20
# Width of each token embedding vector.
EMBED_SIZE: int = 64
# Checkpoint file that stores both the model weights and the vocabulary.
MODEL_PATH: str = "japanese_sentiment_model.pth"
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# =====================
# Tokenizer
# =====================
# A single Janome tokenizer instance, created once when the module is imported
# and shared by every call to tokenize().
tokenizer = Tokenizer()


def tokenize(text):
    """Return the surface strings of the Janome tokens of *text*, in order."""
    morphemes = tokenizer.tokenize(text)
    return [morpheme.surface for morpheme in morphemes]
# =====================
# Model
# =====================
class SentimentModel(nn.Module):
    """Bag-of-embeddings binary sentiment classifier.

    Token ids are embedded, averaged over the sequence dimension, and fed
    through a small MLP that ends in a sigmoid. The result is one probability
    per input row.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Embedding width. ``None`` (the default) uses the
            module-level ``EMBED_SIZE``, matching the original layout.
        hidden_size: Width of the hidden layer in the classifier head.
    """

    def __init__(self, vocab_size, embed_size=None, hidden_size=32):
        super().__init__()
        if embed_size is None:
            # Resolved at call time so the module constant stays the default.
            embed_size = EMBED_SIZE
        # The attribute names and the Sequential layout define the state_dict
        # keys (embedding.weight, fc.0.*, fc.2.*). Keep them stable so saved
        # checkpoints remain loadable.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.fc = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one probability per row of ``x``.

        Args:
            x: LongTensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch,)`` with values in ``[0, 1]``.
        """
        x = self.embedding(x)  # (batch, seq_len, embed_size)
        # NOTE(review): padding positions are averaged in too (no mask).
        # Presumably the checkpoint was trained the same way; confirm before
        # adding masking, since that would change predictions.
        x = x.mean(dim=1)  # (batch, embed_size)
        x = self.fc(x)  # (batch, 1)
        # Squeeze only the trailing unit dim. A bare squeeze() also drops the
        # batch dim when batch == 1, returning a 0-d tensor whose shape no
        # longer matches (1,)-shaped targets.
        return x.squeeze(-1)
# =====================
# Load model + vocab
# =====================
# NOTE(review): depending on the installed torch version's weights_only
# default, torch.load may unpickle arbitrary objects, so MODEL_PATH should
# point only at trusted files. Passing weights_only=True would restrict this,
# presumably fine if the stored vocab is a plain dict; confirm before switching.
checkpoint = torch.load(MODEL_PATH, map_location=device)
vocab = checkpoint["vocab"]  # token -> id mapping consumed by encode()

model = SentimentModel(len(vocab))
model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to evaluation mode for inference
print("Model loaded successfully.")
def encode(text):
    """Turn *text* into exactly MAX_LEN vocabulary ids.

    Tokens missing from the vocabulary map to the ``<UNK>`` id. Shorter
    sequences are right-padded with the ``<PAD>`` id, and longer ones are cut
    to the first MAX_LEN ids.
    """
    ids = [vocab.get(tok, vocab["<UNK>"]) for tok in tokenize(text)]
    shortfall = MAX_LEN - len(ids)
    if shortfall <= 0:
        return ids[:MAX_LEN]
    return ids + [vocab["<PAD>"]] * shortfall
def predict_sentiment(text):
    """Classify *text*, returning "Positive" if the model's probability exceeds 0.5, else "Negative"."""
    batch = torch.tensor([encode(text)], dtype=torch.long).to(device)
    with torch.no_grad():
        probability = model(batch).item()
    return "Positive" if probability > 0.5 else "Negative"
if __name__ == "__main__":
    # Command-line entry point: classify the positional text once, or loop
    # over prompted lines with --interactive.
    parser = argparse.ArgumentParser(
        description="Japanese sentiment prediction CLI using a saved PyTorch model."
    )
    parser.add_argument(
        "text",
        nargs="*",
        help="Text to predict. If omitted, use --interactive.",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Interactive mode. Type text repeatedly (type 'exit' to quit).",
    )
    args = parser.parse_args()

    if args.text:
        # The shell splits the input on spaces, so rejoin the words. The label
        # must be printed explicitly because predict_sentiment only returns it.
        print(predict_sentiment(" ".join(args.text)))
    elif args.interactive:
        while True:
            try:
                text = input("text> ").strip()
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D, end of piped input, or Ctrl-C: leave the loop
                # cleanly instead of exiting with a traceback.
                print()
                break
            if text.lower() in {"exit", "quit"}:
                break
            if text:
                print(predict_sentiment(text))
    else:
        parser.print_help()
|