ayush2917 committed on
Commit
2eb5a40
·
verified ·
1 Parent(s): b013e35

Update src/feature_engineering.py

Browse files
Files changed (1) hide show
  1. src/feature_engineering.py +9 -9
src/feature_engineering.py CHANGED
@@ -1,19 +1,19 @@
1
  # src/feature_engineering.py
2
  from transformers import DistilBertTokenizer
3
- import torch
4
- from src.config import MAX_LENGTH
5
  import logging
 
6
 
7
def setup_logging():
    """Configure root logging to append INFO-level records to logs/app.log."""
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(
        filename="logs/app.log",
        level=logging.INFO,
        format=log_format,
    )
10
 
11
def tokenize_texts(texts):
    """Tokenize texts using DistilBERT tokenizer.

    Parameters
    ----------
    texts : sequence of str
        Texts to tokenize. Accepts any object exposing ``tolist()``
        (e.g. a pandas Series) as well as plain Python sequences.

    Returns
    -------
    transformers.BatchEncoding
        Padded/truncated encodings as PyTorch tensors
        (``return_tensors="pt"``), capped at MAX_LENGTH tokens.
    """
    setup_logging()
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    logging.info("Tokenizing texts")
    # Previously this required texts.tolist() (pandas-only); fall back to
    # list() so plain lists/tuples of strings also work.
    text_list = texts.tolist() if hasattr(texts, "tolist") else list(texts)
    encodings = tokenizer(
        text_list, truncation=True, padding=True, max_length=MAX_LENGTH, return_tensors="pt"
    )
    return encodings
 
1
  # src/feature_engineering.py
2
  from transformers import DistilBertTokenizer
 
 
3
  import logging
4
+ from src.config import MODEL_NAME, MAX_LENGTH, LOG_FILE
5
 
6
def setup_logging():
    """Route INFO-and-above log records to the file named by src.config.LOG_FILE."""
    record_layout = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(
        filename=LOG_FILE,
        level=logging.INFO,
        format=record_layout,
    )
9
 
10
def tokenize_texts(dataset, tokenizer=None, text_column="text"):
    """Tokenize texts using DistilBERT tokenizer.

    Parameters
    ----------
    dataset : datasets.Dataset
        Hugging Face dataset containing a string column to tokenize.
    tokenizer : transformers.PreTrainedTokenizer, optional
        Tokenizer to reuse; when None, a DistilBertTokenizer is loaded
        from the MODEL_NAME checkpoint configured in src.config.
    text_column : str, optional
        Name of the dataset column holding the raw text (default "text").
        Previously hard-coded; parameterized for datasets with other
        column names while keeping the old default behavior.

    Returns
    -------
    tuple
        (tokenized_dataset, tokenizer) — the mapped dataset with token
        columns added, and the tokenizer actually used (returned so the
        caller can reuse it for later splits).
    """
    setup_logging()
    if tokenizer is None:
        tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
    logging.info("Tokenizing texts")

    def tokenize_function(examples):
        # Batched map: examples[text_column] is a list of strings, which the
        # tokenizer pads per batch and truncates to MAX_LENGTH.
        return tokenizer(examples[text_column], truncation=True, padding=True, max_length=MAX_LENGTH)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    return tokenized_dataset, tokenizer