# scripts/main.py
import logging
import os

import torch

from src.config import MODEL_PATH
from src.preprocessing import preprocess_text
from src.utils import load_model_and_tokenizer

def setup_logging():
    # Make sure the log directory exists before attaching the file handler.
    os.makedirs("logs", exist_ok=True)
    logging.basicConfig(filename="logs/app.log", level=logging.INFO,
                        format="%(asctime)s - %(levelname)s - %(message)s")

def predict_text(text, model, tokenizer):
    """Predict the product category for a single piece of input text."""
    cleaned_text = preprocess_text(text)
    # Tokenize as a single-item batch, truncating/padding to at most 128 tokens.
    encodings = tokenizer([cleaned_text], truncation=True, padding=True, max_length=128, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encodings)
        logits = outputs.logits
        pred_label = torch.argmax(logits, dim=1).item()
    label_map = {0: "Electronics", 1: "Household", 2: "Books", 3: "Clothing & Accessories"}
    logging.info(f"Predicted {label_map[pred_label]} for text: {text[:50]}...")
    return label_map[pred_label]

def main():
    # Configure logging once at startup rather than on every prediction.
    setup_logging()
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    while True:
        text = input("Enter text (or 'quit' to exit): ")
        if text.lower() == "quit":
            break
        try:
            prediction = predict_text(text, model, tokenizer)
            print(f"Predicted Category: {prediction}")
        except Exception as e:
            logging.error(f"Error predicting: {e}")
            print("Error processing input. Try again.")

if __name__ == "__main__":
    main()
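
# ---------------------------------------------------------------------------
# Reference sketch (commented out, not executed): load_model_and_tokenizer is
# imported from src.utils, which is not shown in this file. Assuming the
# project uses a Hugging Face transformers sequence-classification model, the
# helper might look roughly like this; the real implementation lives in
# src/utils.py.
#
# from transformers import AutoModelForSequenceClassification, AutoTokenizer
#
# def load_model_and_tokenizer(model_path):
#     """Load a fine-tuned classification model and its tokenizer from disk."""
#     model = AutoModelForSequenceClassification.from_pretrained(model_path)
#     tokenizer = AutoTokenizer.from_pretrained(model_path)
#     model.eval()  # inference mode; predict_text wraps calls in torch.no_grad()
#     return model, tokenizer
# ---------------------------------------------------------------------------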