ayush2917 commited on
Commit
9627370
·
verified ·
1 Parent(s): 6a9e10a

Create scripts/main.py

Browse files
Files changed (1) hide show
  1. scripts/main.py +40 -0
scripts/main.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # scripts/main.py
2
+ from src.preprocessing import preprocess_text
3
+ from src.utils import load_model_and_tokenizer
4
+ from src.feature_engineering import tokenize_texts
5
+ import torch
6
+ import logging
7
+ from src.config import MODEL_PATH
8
+
9
+ def setup_logging():
10
+ logging.basicConfig(filename="logs/app.log", level=logging.INFO,
11
+ format="%(asctime)s - %(levelname)s - %(message)s")
12
+
13
+ def predict_text(text, model, tokenizer):
14
+ """Predict category for input text."""
15
+ setup_logging()
16
+ cleaned_text = preprocess_text(text)
17
+ encodings = tokenizer([cleaned_text], truncation=True, padding=True, max_length=128, return_tensors="pt")
18
+ with torch.no_grad():
19
+ outputs = model(**encodings)
20
+ logits = outputs.logits
21
+ pred_label = torch.argmax(logits, dim=1).item()
22
+ label_map = {0: "Electronics", 1: "Household", 2: "Books", 3: "Clothing & Accessories"}
23
+ logging.info(f"Predicted {label_map[pred_label]} for text: {text[:50]}...")
24
+ return label_map[pred_label]
25
+
26
+ def main():
27
+ model, tokenizer = load_model_and_tokenizer(MODEL_PATH)
28
+ while True:
29
+ text = input("Enter text (or 'quit' to exit): ")
30
+ if text.lower() == "quit":
31
+ break
32
+ try:
33
+ prediction = predict_text(text, model, tokenizer)
34
+ print(f"Predicted Category: {prediction}")
35
+ except Exception as e:
36
+ logging.error(f"Error predicting: {e}")
37
+ print("Error processing input. Try again.")
38
+
39
+ if __name__ == "__main__":
40
+ main()