NaolTaye commited on
Commit
223c003
·
1 Parent(s): 9fde312
Files changed (1) hide show
  1. tasks/text.py +12 -0
tasks/text.py CHANGED
@@ -86,7 +86,19 @@ async def evaluate_text(request: TextEvaluationRequest):
86
  # Model inference
87
  model.eval()
88
  predictions = np.array([])
 
 
89
 
 
 
 
 
 
 
 
 
 
 
90
  with torch.no_grad():
91
  print('BEFORE PREDICTION')
92
 
 
86
  # Model inference
87
  model.eval()
88
  predictions = np.array([])
89
+ batch_size = 32
90
+
91
 
92
+ with torch.no_grad():
93
+ for i in range(0, len(test_dataset['quote']), batch_size):
94
+ batch_quotes = test_dataset['quote'][i:i + batch_size]
95
+ print(f'Processing batch {i // batch_size + 1}')
96
+
97
+ # Tokenize the input data for the current batch
98
+ tokenized_inputs = tokenizer(batch_quotes, padding=True, truncation=True, return_tensors='pt').to(device)
99
+
100
+ # Forward pass through the model
101
+ outputs = model(**tokenized_inputs)
102
  with torch.no_grad():
103
  print('BEFORE PREDICTION')
104