# Japanese sentiment inference CLI (Janome tokenizer + saved PyTorch model).
| import torch | |
| import torch.nn as nn | |
| from janome.tokenizer import Tokenizer | |
| import argparse | |
# =====================
# Settings
# =====================
MAX_LEN = 20      # fixed token-sequence length fed to the model (inputs are padded/truncated to this)
EMBED_SIZE = 64   # embedding dimension; must match the trained checkpoint
MODEL_PATH = "japanese_sentiment_model.pth"  # checkpoint file produced at training time
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================
# Tokenizer
# =====================
# Janome morphological analyzer instance, shared by tokenize() below.
tokenizer = Tokenizer()
def tokenize(text):
    """Split *text* into surface-form tokens via the module-level Janome tokenizer."""
    return [tok.surface for tok in tokenizer.tokenize(text)]
| # ===================== | |
| # Model | |
| # ===================== | |
class SentimentModel(nn.Module):
    """Binary sentiment classifier: mean-pooled token embeddings + small MLP.

    Architecture: Embedding -> mean over the sequence axis ->
    Linear(embed_size, 32) -> ReLU -> Linear(32, 1) -> Sigmoid.
    """

    def __init__(self, vocab_size: int, embed_size: int = 64):
        """
        Args:
            vocab_size: number of rows in the embedding table.
            embed_size: embedding dimension. Default 64 matches the
                module-level EMBED_SIZE used when the checkpoint was trained.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.fc = nn.Sequential(
            nn.Linear(embed_size, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: (batch, seq_len) integer token ids
        emb = self.embedding(x)   # (batch, seq_len, embed_size)
        pooled = emb.mean(dim=1)  # (batch, embed_size) — PAD ids are included in the mean
        prob = self.fc(pooled)    # (batch, 1), values in (0, 1)
        # squeeze(-1), not bare squeeze(): squeeze() with no dim would also
        # collapse a batch dimension of size 1 and return a 0-dim tensor for
        # single-sample batches instead of shape (1,).
        return prob.squeeze(-1)   # (batch,)
# =====================
# Load model + vocab
# =====================
# NOTE(review): torch.load with default arguments unpickles arbitrary objects —
# only load checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=device)
vocab = checkpoint["vocab"]  # token -> id mapping saved alongside the weights
model = SentimentModel(len(vocab)).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode: disables dropout / batch-norm updates
print("Model loaded successfully.")
def encode(text):
    """Map *text* to a fixed-length (MAX_LEN) list of vocab ids.

    Out-of-vocabulary tokens become <UNK>; long inputs are truncated and
    short ones right-padded with <PAD>.
    """
    ids = [vocab.get(tok, vocab["<UNK>"]) for tok in tokenize(text)][:MAX_LEN]
    ids += [vocab["<PAD>"]] * (MAX_LEN - len(ids))
    return ids
def predict_sentiment(text):
    """Run the model on *text* and return "Positive" or "Negative"."""
    batch = torch.tensor([encode(text)], dtype=torch.long).to(device)
    with torch.no_grad():
        score = model(batch).item()
    return "Positive" if score > 0.5 else "Negative"
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Japanese sentiment prediction CLI using a saved PyTorch model."
    )
    parser.add_argument(
        "text",
        nargs="*",
        help="Text to predict. If omitted, use --interactive.",
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Interactive mode. Type text repeatedly (type 'exit' to quit).",
    )
    args = parser.parse_args()

    if args.text:
        # BUG FIX: the prediction was computed but its return value was
        # discarded, so the CLI produced no output. Print the label.
        print(predict_sentiment(" ".join(args.text)))
    elif args.interactive:
        while True:
            # EOFError (Ctrl-D / closed stdin) ends the session cleanly
            # instead of crashing with a traceback.
            try:
                text = input("text> ").strip()
            except EOFError:
                break
            if text.lower() in {"exit", "quit"}:
                break
            if text:
                print(predict_sentiment(text))
    else:
        parser.print_help()