Spaces:
No application file
No application file
# scripts/main.py
import logging
import os

import torch

from src.config import MODEL_PATH
from src.feature_engineering import tokenize_texts
from src.preprocessing import preprocess_text
from src.utils import load_model_and_tokenizer
def setup_logging():
    """Configure root logging to write timestamped records to logs/app.log.

    ``logging.basicConfig`` is a no-op once the root logger has handlers,
    so repeated calls are safe.
    """
    # basicConfig raises FileNotFoundError when the target directory is
    # missing, so make sure logs/ exists before configuring the handler.
    os.makedirs("logs", exist_ok=True)
    logging.basicConfig(
        filename="logs/app.log",
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
def predict_text(text, model, tokenizer):
    """Predict the product category for a single piece of text.

    Args:
        text: Raw input string to classify.
        model: Sequence-classification model whose output exposes ``.logits``.
        tokenizer: HuggingFace-style tokenizer returning PyTorch tensors.

    Returns:
        Human-readable category name for the highest-scoring class.

    Raises:
        KeyError: If the model predicts a class index outside the label map.
    """
    setup_logging()  # idempotent: basicConfig no-ops once configured
    cleaned_text = preprocess_text(text)
    encodings = tokenizer(
        [cleaned_text],
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**encodings)
    pred_label = torch.argmax(outputs.logits, dim=1).item()
    label_map = {0: "Electronics", 1: "Household", 2: "Books", 3: "Clothing & Accessories"}
    category = label_map[pred_label]
    # Lazy %-args: the message is only formatted if INFO is actually emitted.
    logging.info("Predicted %s for text: %s...", category, text[:50])
    return category
def main():
    """Interactive loop: read text from stdin and print predicted categories.

    Type 'quit' (or send Ctrl-D / Ctrl-C) to exit.
    """
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    while True:
        try:
            text = input("Enter text (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Treat end-of-input / Ctrl-C as a clean exit, not a traceback.
            break
        if text.lower() == "quit":
            break
        try:
            prediction = predict_text(text, model, tokenizer)
        except Exception:
            # Top-level boundary: record the full traceback and keep serving.
            logging.exception("Error predicting input")
            print("Error processing input. Try again.")
        else:
            print(f"Predicted Category: {prediction}")


if __name__ == "__main__":
    main()