FarhanAK128 commited on
Commit
1690b97
·
verified ·
1 Parent(s): 91a975f

Update model_class.py

Browse files
Files changed (1) hide show
  1. model_class.py +3 -3
model_class.py CHANGED
@@ -175,7 +175,7 @@ class TicketGPT(
175
  logits = self.out_head(x) #[2,4,50257]
176
  return logits
177
 
178
- def classify_review(text, model, tokenizer, max_length=None, pad_token_id=50256):
179
  lookup = {
180
  0:"Hardware",
181
  1:"HR Support",
@@ -188,7 +188,7 @@ class TicketGPT(
188
  }
189
 
190
  current_device = next(self.parameters()).device
191
- model.eval()
192
 
193
  # Prepare inputs to the model
194
  input_ids = tokenizer.encode(text)
@@ -203,7 +203,7 @@ class TicketGPT(
203
 
204
  # Model inference
205
  with torch.no_grad():
206
- logits = model(input_tensor)[:, -1, :] # Logits of the last output token
207
  predicted_label = torch.argmax(logits, dim=-1).item()
208
 
209
  # Return the classified result
 
175
  logits = self.out_head(x) #[2,4,50257]
176
  return logits
177
 
178
+ def predict(self, text, tokenizer, max_length=None, pad_token_id=50256):
179
  lookup = {
180
  0:"Hardware",
181
  1:"HR Support",
 
188
  }
189
 
190
  current_device = next(self.parameters()).device
191
+ self.eval()
192
 
193
  # Prepare inputs to the model
194
  input_ids = tokenizer.encode(text)
 
203
 
204
  # Model inference
205
  with torch.no_grad():
206
+ logits = self.forward(input_tensor)[:, -1, :] # Logits of the last output token
207
  predicted_label = torch.argmax(logits, dim=-1).item()
208
 
209
  # Return the classified result