# NOTE: removed non-Python web-scrape residue (file-listing header) that preceded this script.
# scripts/main.py
import logging
import os

import torch

from src.config import MODEL_PATH
from src.feature_engineering import tokenize_texts
from src.preprocessing import preprocess_text
from src.utils import load_model_and_tokenizer
def setup_logging():
    """Configure root logging to append INFO+ records to logs/app.log.

    Creates the ``logs/`` directory if missing — ``logging.basicConfig``
    would otherwise raise ``FileNotFoundError`` on first use.  Safe to call
    repeatedly: ``basicConfig`` is a no-op once the root logger has handlers.
    """
    os.makedirs("logs", exist_ok=True)
    logging.basicConfig(
        filename="logs/app.log",
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
def predict_text(text, model, tokenizer):
    """Predict a category name for one input text.

    Args:
        text: Raw input string to classify.
        model: Sequence-classification model whose output exposes ``.logits``.
        tokenizer: Tokenizer matching ``model``; called with a list of strings.

    Returns:
        Human-readable category name for the highest-scoring label, or
        ``"Unknown(<id>)"`` if the model emits a label id outside the map.
    """
    setup_logging()  # no-op after the first call (root logger already configured)
    cleaned_text = preprocess_text(text)
    encodings = tokenizer(
        [cleaned_text],
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt",
    )
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encodings).logits
    pred_label = torch.argmax(logits, dim=1).item()
    label_map = {0: "Electronics", 1: "Household", 2: "Books", 3: "Clothing & Accessories"}
    # .get guards against a checkpoint trained with a different label count,
    # which would otherwise raise KeyError here.
    category = label_map.get(pred_label, f"Unknown({pred_label})")
    # Lazy %-formatting: the message is only rendered if INFO is enabled.
    logging.info("Predicted %s for text: %s...", category, text[:50])
    return category
def main():
    """Interactive loop: read text from stdin and print its predicted category.

    Exits cleanly on the literal input ``quit`` (case-insensitive), on EOF
    (Ctrl-D), or on Ctrl-C at the prompt.  Blank input is skipped rather than
    sent to the model.  Prediction errors are logged and reported without
    terminating the loop.
    """
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
    while True:
        try:
            text = input("Enter text (or 'quit' to exit): ")
        except (EOFError, KeyboardInterrupt):
            # Treat end-of-input / Ctrl-C at the prompt as a clean exit
            # instead of crashing with a traceback.
            break
        if text.lower() == "quit":
            break
        if not text.strip():
            # Nothing to classify — re-prompt.
            continue
        try:
            prediction = predict_text(text, model, tokenizer)
        except Exception as e:
            # Top-level boundary: log and keep the loop alive.
            logging.error("Error predicting: %s", e)
            print("Error processing input. Try again.")
        else:
            print(f"Predicted Category: {prediction}")
# Standard script entry guard (stray scrape artifact after main() removed).
if __name__ == "__main__":
    main()