# scripts/main.py
"""Interactive CLI that classifies free text into one of four product categories."""
import logging
import os

import torch

from src.config import MODEL_PATH
from src.feature_engineering import tokenize_texts
from src.preprocessing import preprocess_text
from src.utils import load_model_and_tokenizer

# Fixed mapping from model output index to human-readable category name.
# Hoisted to module level so it is not rebuilt on every prediction.
LABEL_MAP = {0: "Electronics", 1: "Household", 2: "Books", 3: "Clothing & Accessories"}


def setup_logging():
    """Configure file logging, creating the log directory if it is missing."""
    # basicConfig raises FileNotFoundError if logs/ does not exist yet.
    os.makedirs("logs", exist_ok=True)
    logging.basicConfig(
        filename="logs/app.log",
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )


def predict_text(text, model, tokenizer):
    """Predict the category for a single input text.

    Args:
        text: Raw input string to classify.
        model: Sequence-classification model whose output exposes `.logits`.
        tokenizer: Tokenizer matching the model; called with a list of strings.

    Returns:
        str: The predicted category name from LABEL_MAP.
    """
    cleaned_text = preprocess_text(text)
    encodings = tokenizer(
        [cleaned_text],
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encodings)
    logits = outputs.logits
    pred_label = torch.argmax(logits, dim=1).item()
    # Lazy %-args: the message is only formatted if the record is emitted.
    logging.info("Predicted %s for text: %s...", LABEL_MAP[pred_label], text[:50])
    return LABEL_MAP[pred_label]


def main():
    """Run the interactive prediction loop until the user quits."""
    # Configure logging once at startup; the original re-ran basicConfig
    # inside predict_text on every call.
    setup_logging()
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    while True:
        try:
            text = input("Enter text (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            break
        if text.lower() == "quit":
            break
        try:
            prediction = predict_text(text, model, tokenizer)
            print(f"Predicted Category: {prediction}")
        except Exception:
            # Top-level boundary: log the full traceback, keep the loop alive.
            logging.exception("Error predicting")
            print("Error processing input. Try again.")


if __name__ == "__main__":
    main()