humair025 committed on
Commit
8f5d24b
·
verified ·
1 Parent(s): 44cb2fc

Update soprano/backends/transformers.py

Browse files
Files changed (1) hide show
  1. soprano/backends/transformers.py +9 -3
soprano/backends/transformers.py CHANGED
@@ -9,9 +9,12 @@ class TransformersModel(BaseModel):
9
  **kwargs):
10
  self.device = device
11
 
 
 
 
12
  self.model = AutoModelForCausalLM.from_pretrained(
13
  'ekwek/Soprano-80M',
14
- torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32,
15
  device_map=device
16
  )
17
  self.tokenizer = AutoTokenizer.from_pretrained('ekwek/Soprano-80M')
@@ -43,6 +46,7 @@ class TransformersModel(BaseModel):
43
  return_dict_in_generate=True,
44
  output_hidden_states=True,
45
  )
 
46
  res = []
47
  eos_token_id = self.model.config.eos_token_id
48
  for i in range(len(prompts)):
@@ -51,7 +55,9 @@ class TransformersModel(BaseModel):
51
  num_output_tokens = len(outputs.hidden_states)
52
  for j in range(num_output_tokens):
53
  token = seq[j + seq.size(0) - num_output_tokens]
54
- if token != eos_token_id: hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])
 
 
55
  last_hidden_state = torch.stack(hidden_states).squeeze()
56
  finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
57
  res.append({
@@ -65,4 +71,4 @@ class TransformersModel(BaseModel):
65
  top_p=0.95,
66
  temperature=0.3,
67
  repetition_penalty=1.2):
68
- raise NotImplementedError("transformers backend does not currently support streaming, please consider using lmdeploy backend instead.")
 
9
  **kwargs):
10
  self.device = device
11
 
12
+ # Set appropriate dtype based on device
13
+ dtype = torch.bfloat16 if device == 'cuda' else torch.float32
14
+
15
  self.model = AutoModelForCausalLM.from_pretrained(
16
  'ekwek/Soprano-80M',
17
+ torch_dtype=dtype,
18
  device_map=device
19
  )
20
  self.tokenizer = AutoTokenizer.from_pretrained('ekwek/Soprano-80M')
 
46
  return_dict_in_generate=True,
47
  output_hidden_states=True,
48
  )
49
+
50
  res = []
51
  eos_token_id = self.model.config.eos_token_id
52
  for i in range(len(prompts)):
 
55
  num_output_tokens = len(outputs.hidden_states)
56
  for j in range(num_output_tokens):
57
  token = seq[j + seq.size(0) - num_output_tokens]
58
+ if token != eos_token_id:
59
+ hidden_states.append(outputs.hidden_states[j][-1][i, -1, :])
60
+
61
  last_hidden_state = torch.stack(hidden_states).squeeze()
62
  finish_reason = 'stop' if seq[-1].item() == eos_token_id else 'length'
63
  res.append({
 
71
  top_p=0.95,
72
  temperature=0.3,
73
  repetition_penalty=1.2):
74
+ raise NotImplementedError("transformers backend does not currently support streaming, please consider using lmdeploy backend instead.")