mrm8488 committed
Commit 463b96d · verified · 1 parent: 56a7af3

Update README.md

Files changed (1)
  1. README.md +40 -59
README.md CHANGED
@@ -63,76 +63,57 @@ TBD
 ```py
 from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria

 tokenizer = AutoTokenizer.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct")
-model = AutoModelForCausalLM.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct")

 class MyStoppingCriteria(StoppingCriteria):
-    def __init__(self, target_sequence, prompt):
-        self.target_sequence = target_sequence
-        self.prompt = prompt
-
-    def __call__(self, input_ids, scores, **kwargs):
-        # Get the generated text as a string
-        generated_text = tokenizer.decode(input_ids[0])
-        generated_text = generated_text.replace(self.prompt, '')
-        # Check if the target sequence appears in the generated text
-        if self.target_sequence in generated_text:
-            return True  # Stop generation
-        return False  # Continue generation
-
-    def __len__(self):
-        return 1
-
-    def __iter__(self):
-        yield self
-
-def generate(
-    context,
-    max_new_tokens=256,
-    min_new_tokens=64,
-    temperature=0.3,
-    top_p=0.75,
-    top_k=40,
-    do_sample=False,
-    num_beams=2,
-    **kwargs,
-):
-    prompt = context
-    # print(prompt)
-    inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to("cuda")
     attention_mask = inputs["attention_mask"].to("cuda")
-    generation_config = GenerationConfig(
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        do_sample=do_sample,
-        num_beams=num_beams,
-        **kwargs,
-    )
     with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            # generation_config=generation_config,
-            do_sample=True,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=max_new_tokens,
-            min_new_tokens=min_new_tokens,
-            early_stopping=False,
-            use_cache=True,
-            stopping_criteria=MyStoppingCriteria("### human:", prompt)
-        )
-    s = generation_output.sequences[0]
-    output = tokenizer.decode(s)
     return output

 human = """### human: Among the reference ID of under 10 who got response by marketing department, compare their education status.
 There are 2 tables involved with this task. The name of the 1st table is Customers, and the headers of this table are ID,SEX,MARITAL_STATUS,GEOID,EDUCATIONNUM,OCCUPATION,age. The name of the 2nd table is Mailings1_2, and the headers of this table are REFID,REF_DATE,RESPONSE."""
-context = context + '\n' + human

 solution = generate(context)
 print(solution)
 ```
 ```py
 from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
+from transformers import StoppingCriteriaList
+import torch

+# Load tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct")
+model = AutoModelForCausalLM.from_pretrained("mrm8488/mistral-7b-ft-AgentInstruct").to("cuda")

 class MyStoppingCriteria(StoppingCriteria):
+    def __init__(self, target_sequence, prompt):
+        self.target_sequence = target_sequence
+        self.prompt = prompt
+
+    def __call__(self, input_ids, scores, **kwargs):
+        # Decode without the prompt and check for the target sequence
+        generated_text = tokenizer.decode(input_ids[0]).replace(self.prompt, "")
+        return self.target_sequence in generated_text
+
+def generate(context, max_new_tokens=256, min_new_tokens=64, temperature=0.3,
+             top_p=0.75, top_k=40, do_sample=True, num_beams=2):
+    # Prepare input data
+    inputs = tokenizer(context, return_tensors="pt")
     input_ids = inputs["input_ids"].to("cuda")
     attention_mask = inputs["attention_mask"].to("cuda")
+
+    # Generation settings (return_dict_in_generate is required for .sequences below)
+    generation_settings = {
+        "max_new_tokens": max_new_tokens,
+        "min_new_tokens": min_new_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "do_sample": do_sample,
+        "num_beams": num_beams,
+        "early_stopping": False,
+        "use_cache": True,
+        "return_dict_in_generate": True,
+        "stopping_criteria": StoppingCriteriaList([MyStoppingCriteria("### human:", context)]),
+    }
+
+    # Generate response
     with torch.no_grad():
+        generation_output = model.generate(
+            input_ids=input_ids, attention_mask=attention_mask, **generation_settings
+        )
+
+    output = tokenizer.decode(generation_output.sequences[0])
     return output

+# Example usage
 human = """### human: Among the reference ID of under 10 who got response by marketing department, compare their education status.
 There are 2 tables involved with this task. The name of the 1st table is Customers, and the headers of this table are ID,SEX,MARITAL_STATUS,GEOID,EDUCATIONNUM,OCCUPATION,age. The name of the 2nd table is Mailings1_2, and the headers of this table are REFID,REF_DATE,RESPONSE."""
+context = human

 solution = generate(context)
 print(solution)
 ```
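
One caveat when reusing the snippet above: `MyStoppingCriteria` only fires once the `### human:` marker has already been generated, so the decoded output still contains the echoed prompt plus the trailing marker. Below is a minimal post-processing sketch; the `extract_reply` helper is hypothetical (not part of the model card), and the exact-string replace is approximate, since decoding can alter whitespace.

```py
# Hypothetical post-processing helper (an assumption, not part of the README):
# the stopping criterion fires only after "### human:" has been generated, so
# the decoded output still carries the echoed prompt and the stop marker.
def extract_reply(output: str, prompt: str, stop_marker: str = "### human:") -> str:
    reply = output.replace(prompt, "", 1)          # drop the echoed prompt
    return reply.split(stop_marker, 1)[0].strip()  # cut at the first stop marker

print(extract_reply(solution, context))
```

On memory-constrained GPUs, loading the weights in half precision (passing `torch_dtype=torch.float16` to `from_pretrained`) roughly halves the footprint of the 7B parameters relative to the default float32.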