smrstep commited on
Commit
ac33228
·
verified ·
1 Parent(s): 1b573fa

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +29 -1
README.md CHANGED
@@ -46,6 +46,34 @@ As is, CARROT supports routing to the following collection of large language mod
46
  ```python
47
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
48
  import numpy as np
49
-
50
  token = 'YOUR HF TOKEN'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  ```
 
46
  ```python
47
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
48
  import numpy as np
 
49
  token = 'YOUR HF TOKEN'
50
+ nput_counter = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-70B", token='')
51
+
52
+ tokenizer = AutoTokenizer.from_pretrained('roberta-base')
53
+
54
+ score_predictor = AutoModelForSequenceClassification.from_pretrained('CARROT-LLM-Routing/Performance',
55
+ problem_type="multi_label_classification",
56
+ num_labels=len(COSTS),
57
+ )
58
+ output_counter = AutoModelForSequenceClassification.from_pretrained('CARROT-LLM-Routing/Cost',
59
+ problem_type="regression",
60
+ num_labels=len(COSTS))
61
+ def CARROT(prompts, mu, input_counter=input_counter, predictors = [score_predictor, output_counter], tokenizer=tokenizer, costs=COSTS):
62
+ tokenized_text = tokenizer(prompts,
63
+ truncation=True,
64
+ padding=True,
65
+ is_split_into_words=False,
66
+ return_tensors='pt')
67
+ input_counter.pad_token = tokenizer.eos_token
68
+ scores = 1/(1+np.exp(-predictors[1](tokenized_text["input_ids"]).logits.detach().numpy()))
69
+ output_tokens = predictors[1](tokenized_text["input_ids"]).logits.detach().numpy()
70
+ input_tokens = [input_counter(prompt, return_tensors="pt")["input_ids"].shape[1] for prompt in prompts]
71
+ input_tokens = np.array(input_tokens).T
72
+ costs = []
73
+ for i, m in enumerate(COSTS.keys()):
74
+ costs.append((input_tokens*COSTS[m][0]/(1000000)+output_tokens[:,i]*COSTS[m][1]/1000).tolist())
75
+ costs = np.array(costs).T
76
+ model_idx = ((1 - mu) * scores - mu * costs*100 ).argmax(axis = 1, keepdims = True)
77
+ called = [id2label[idx[0]] for idx in model_idx]
78
+ return called
79
  ```