velocity-ai committed on
Commit
4f1bef3
·
verified ·
1 Parent(s): 1043976

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +21 -25
code/inference.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  import json
3
  import torch
4
- import torch.nn as nn
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
@@ -12,22 +11,7 @@ logger = logging.getLogger(__name__)
12
  # Can specify GPU device with:
13
  # CUDA_VISIBLE_DEVICES="1" python script.py
14
 
15
- class PhiForSequenceClassification(nn.Module):
16
- def __init__(self, base_model, num_labels=2):
17
- super().__init__()
18
- self.phi = base_model
19
- # Create classifier with same dtype as base model
20
- dtype = next(base_model.parameters()).dtype
21
- self.classifier = nn.Linear(self.phi.config.hidden_size, num_labels, dtype=dtype)
22
-
23
- def forward(self, **inputs):
24
- outputs = self.phi(**inputs, output_hidden_states=True)
25
- # Use the last hidden state of the last token for classification
26
- last_hidden_state = outputs.hidden_states[-1][:, -1, :]
27
- logits = self.classifier(last_hidden_state)
28
- return type('Outputs', (), {'logits': logits})()
29
-
30
- def model_fn(model_dir, context=None):
31
  """Load the model for inference"""
32
  try:
33
  model_id = os.getenv("HF_MODEL_ID")
@@ -42,16 +26,19 @@ def model_fn(model_dir, context=None):
42
  # Load tokenizer
43
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
44
 
45
- # Load base model
46
- base_model = AutoModelForCausalLM.from_pretrained(
 
 
 
 
 
47
  model_id,
 
48
  torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
49
  trust_remote_code=True
50
  )
51
 
52
- # Create classification model
53
- model = PhiForSequenceClassification(base_model, num_labels=2)
54
-
55
  # Move model to device
56
  model = model.to(device)
57
 
@@ -83,13 +70,22 @@ def predict_fn(data, model_dict):
83
 
84
  logger.info(f"Model is on device: {device}")
85
 
86
- # Parse input
87
  if isinstance(data, str):
88
  input_text = data
89
  elif isinstance(data, dict):
90
- input_text = data.get("inputs", data.get("text", str(data)))
 
 
 
 
 
 
 
 
91
  else:
92
  input_text = str(data)
 
93
  logger.debug(f"Parsed input text: {input_text}")
94
 
95
  # Create tensors directly on target device
 
1
  import os
2
  import json
3
  import torch
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 
5
  import logging
6
 
7
  logger = logging.getLogger(__name__)
 
11
  # Can specify GPU device with:
12
  # CUDA_VISIBLE_DEVICES="1" python script.py
13
 
14
+ def model_fn(model_dir):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  """Load the model for inference"""
16
  try:
17
  model_id = os.getenv("HF_MODEL_ID")
 
26
  # Load tokenizer
27
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
28
 
29
+ # Load config
30
+ config = AutoConfig.from_pretrained(model_id,
31
+ num_labels=2,
32
+ trust_remote_code=True)
33
+
34
+ # Load model with sequence classification head
35
+ model = AutoModelForSequenceClassification.from_pretrained(
36
  model_id,
37
+ config=config,
38
  torch_dtype=torch.bfloat16 if device.type == 'cuda' else torch.float32,
39
  trust_remote_code=True
40
  )
41
 
 
 
 
42
  # Move model to device
43
  model = model.to(device)
44
 
 
70
 
71
  logger.info(f"Model is on device: {device}")
72
 
73
+ # Parse input and format it like training data
74
  if isinstance(data, str):
75
  input_text = data
76
  elif isinstance(data, dict):
77
+ # Extract address components
78
+ addr1 = data.get('order_address1', data.get('address_line_1', ''))
79
+ addr2 = data.get('order_address2', data.get('address_line_2', ''))
80
+ city = data.get('order_city', data.get('city', ''))
81
+ state = data.get('order_state', data.get('state', ''))
82
+ pincode = str(data.get('order_pincode', data.get('pincode', '')))
83
+
84
+ # Format exactly like training data
85
+ input_text = f"Address_line_1: {addr1} Address_line_2: {addr2} City: {city} State: {state} Pincode: {pincode}"
86
  else:
87
  input_text = str(data)
88
+
89
  logger.debug(f"Parsed input text: {input_text}")
90
 
91
  # Create tensors directly on target device