Update model_class.py
Browse files- model_class.py +3 -3
model_class.py
CHANGED
|
@@ -175,7 +175,7 @@ class TicketGPT(
|
|
| 175 |
logits = self.out_head(x) #[2,4,50257]
|
| 176 |
return logits
|
| 177 |
|
| 178 |
-
def
|
| 179 |
lookup = {
|
| 180 |
0:"Hardware",
|
| 181 |
1:"HR Support",
|
|
@@ -188,7 +188,7 @@ class TicketGPT(
|
|
| 188 |
}
|
| 189 |
|
| 190 |
current_device = next(self.parameters()).device
|
| 191 |
-
|
| 192 |
|
| 193 |
# Prepare inputs to the model
|
| 194 |
input_ids = tokenizer.encode(text)
|
|
@@ -203,7 +203,7 @@ class TicketGPT(
|
|
| 203 |
|
| 204 |
# Model inference
|
| 205 |
with torch.no_grad():
|
| 206 |
-
logits =
|
| 207 |
predicted_label = torch.argmax(logits, dim=-1).item()
|
| 208 |
|
| 209 |
# Return the classified result
|
|
|
|
| 175 |
logits = self.out_head(x) #[2,4,50257]
|
| 176 |
return logits
|
| 177 |
|
| 178 |
+
def predict(self, text, tokenizer, max_length=None, pad_token_id=50256):
|
| 179 |
lookup = {
|
| 180 |
0:"Hardware",
|
| 181 |
1:"HR Support",
|
|
|
|
| 188 |
}
|
| 189 |
|
| 190 |
current_device = next(self.parameters()).device
|
| 191 |
+
self.eval()
|
| 192 |
|
| 193 |
# Prepare inputs to the model
|
| 194 |
input_ids = tokenizer.encode(text)
|
|
|
|
| 203 |
|
| 204 |
# Model inference
|
| 205 |
with torch.no_grad():
|
| 206 |
+
logits = self.forward(input_tensor)[:, -1, :] # Logits of the last output token
|
| 207 |
predicted_label = torch.argmax(logits, dim=-1).item()
|
| 208 |
|
| 209 |
# Return the classified result
|