mrm8488 committed
Commit 56a7af3 · verified · 1 Parent(s): 4a6759c

Update README.md

Files changed (1)
  1. README.md +86 -127
README.md CHANGED
@@ -1,11 +1,15 @@
  ---
  library_name: transformers
- tags: []
  ---

  # Mistral-7B fine-tuned on AgentInstruct

- [Mistral-7b-v1.0]() fine-tuned on the dataset [AgentInstruct] for "*better* acting as an agent"


@@ -53,128 +57,83 @@ AgentInstruct includes 1,866 trajectories from
  stands for filtered trajectories.

  ## Training Details
-
- ### Training Data
-
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
-
- [More Information Needed]
-
- ### Training Procedure
-
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
-
- #### Preprocessing [optional]
-
- [More Information Needed]
-
-
- #### Training Hyperparameters
-
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
-
- #### Speeds, Sizes, Times [optional]
-
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
-
- [More Information Needed]
-
- ## Evaluation
-
- <!-- This section describes the evaluation protocols and provides the results. -->
-
- ### Testing Data, Factors & Metrics
-
- #### Testing Data
-
- <!-- This should link to a Dataset Card if possible. -->
-
- [More Information Needed]
-
- #### Factors
-
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
-
- [More Information Needed]
-
- #### Metrics
-
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
-
- [More Information Needed]
-
- ### Results
-
- [More Information Needed]
-
- #### Summary
-
-
-
- ## Model Examination [optional]
-
- <!-- Relevant interpretability work for the model goes here -->
-
- [More Information Needed]
-
- ## Environmental Impact
-
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
-
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
-
- - **Hardware Type:** [More Information Needed]
- - **Hours used:** [More Information Needed]
- - **Cloud Provider:** [More Information Needed]
- - **Compute Region:** [More Information Needed]
- - **Carbon Emitted:** [More Information Needed]
-
- ## Technical Specifications [optional]
-
- ### Model Architecture and Objective
-
- [More Information Needed]
-
- ### Compute Infrastructure
-
- [More Information Needed]
-
- #### Hardware
-
- [More Information Needed]
-
- #### Software
-
- [More Information Needed]
-
- ## Citation [optional]
-
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
-
- **BibTeX:**
-
- [More Information Needed]
-
- **APA:**
-
- [More Information Needed]
-
- ## Glossary [optional]
-
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
-
- [More Information Needed]
-
- ## More Information [optional]
-
- [More Information Needed]
-
- ## Model Card Authors [optional]
-
- [More Information Needed]
-
- ## Model Card Contact
-
- [More Information Needed]
-
-
 
  ---
  library_name: transformers
+ license: apache-2.0
+ datasets:
+ - THUDM/AgentInstruct
+ language:
+ - en
  ---

  # Mistral-7B fine-tuned on AgentInstruct

+ [Mistral-7b-v1.0]() fine-tuned on the [AgentInstruct](https://huggingface.co/datasets/THUDM/AgentInstruct) dataset to act *better* as an agent.


  stands for filtered trajectories.

  ## Training Details
+ TBD
+
+ ## Example of usage
+
+ ```py
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     GenerationConfig,
+     StoppingCriteria,
+ )
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct")
+ model = AutoModelForCausalLM.from_pretrained(
+     "mrm8488/mistral-7b-ft-AgentInstruct",
+     torch_dtype=torch.float16,  # fp16 so the 7B model fits on a single GPU
+ ).to(device)
+
+ class MyStoppingCriteria(StoppingCriteria):
+     def __init__(self, target_sequence, prompt):
+         self.target_sequence = target_sequence
+         self.prompt = prompt
+
+     def __call__(self, input_ids, scores, **kwargs):
+         # Decode everything generated so far and strip the prompt
+         generated_text = tokenizer.decode(input_ids[0])
+         generated_text = generated_text.replace(self.prompt, "")
+         # Stop as soon as the target sequence appears in the generated text
+         return self.target_sequence in generated_text
+
+     # __len__/__iter__ let a single criterion be passed where `generate`
+     # expects a list-like StoppingCriteriaList
+     def __len__(self):
+         return 1
+
+     def __iter__(self):
+         yield self
+
+ def generate(
+     context,
+     max_new_tokens=256,
+     min_new_tokens=64,
+     temperature=0.3,
+     top_p=0.75,
+     top_k=40,
+     do_sample=False,
+     num_beams=2,
+     **kwargs,
+ ):
+     prompt = context
+     inputs = tokenizer(prompt, return_tensors="pt")
+     input_ids = inputs["input_ids"].to(device)
+     attention_mask = inputs["attention_mask"].to(device)
+     generation_config = GenerationConfig(
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         do_sample=do_sample,
+         num_beams=num_beams,
+         **kwargs,
+     )
+     with torch.no_grad():
+         generation_output = model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+             return_dict_in_generate=True,
+             output_scores=True,
+             max_new_tokens=max_new_tokens,
+             min_new_tokens=min_new_tokens,
+             early_stopping=False,
+             use_cache=True,
+             stopping_criteria=MyStoppingCriteria("### human:", prompt),
+         )
+     s = generation_output.sequences[0]
+     output = tokenizer.decode(s)
+     return output
+
+ # Prior conversation context (empty here) followed by the task description
+ context = ""
+ human = """### human: Among the reference ID of under 10 who got response by marketing department, compare their education status.
+ There are 2 tables involved with this task. The name of the 1st table is Customers, and the headers of this table are ID,SEX,MARITAL_STATUS,GEOID,EDUCATIONNUM,OCCUPATION,age. The name of the 2nd table is Mailings1_2, and the headers of this table are REFID,REF_DATE,RESPONSE."""
+ context = context + '\n' + human
+
+ solution = generate(context)
+ print(solution)
+ ```
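
If the checkpoint above does not fit in GPU memory at full precision, a quantized load is a common alternative. The snippet below is a minimal sketch, not part of the model card itself, and assumes `bitsandbytes` and `accelerate` are installed; generation then works exactly as in the card's example.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization cuts memory to roughly a quarter of the fp16 footprint
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct")
model = AutoModelForCausalLM.from_pretrained(
    "mrm8488/mistral-7b-ft-AgentInstruct",
    quantization_config=bnb_config,
    device_map="auto",  # requires `accelerate`
)
```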
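
For reference, the AgentInstruct trajectories used for fine-tuning can be inspected with the `datasets` library. This is a sketch only: the split layout and record fields are assumptions, so check the dataset card at https://huggingface.co/datasets/THUDM/AgentInstruct if they differ.

```py
from datasets import load_dataset

# Load the fine-tuning data; AgentInstruct groups trajectories by agent task
# (split names are an assumption here -- see the dataset card for the exact ones).
ds = load_dataset("THUDM/AgentInstruct")
print(ds)  # available splits and number of trajectories per split

first_split = next(iter(ds.values()))
print(first_split[0])  # one multi-turn trajectory
```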