longdiyao committed on
Commit
a2d7478
·
verified ·
1 Parent(s): 8354251

Update src/models/flan.py

Browse files
Files changed (1) hide show
  1. src/models/flan.py +5 -4
src/models/flan.py CHANGED
@@ -1,4 +1,3 @@
1
- # models/flan.py
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
@@ -9,7 +8,7 @@ class FlanT5:
9
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
10
  model_id,
11
  device_map="auto"
12
- )
13
 
14
  def generate(self, prompt, temperature=0.7):
15
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
@@ -18,6 +17,8 @@ class FlanT5:
18
  do_sample=True,
19
  temperature=temperature,
20
  top_p=0.9,
21
- max_new_tokens=100
 
 
22
  )
23
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
 
 
8
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
9
  model_id,
10
  device_map="auto"
11
+ ).eval()
12
 
13
  def generate(self, prompt, temperature=0.7):
14
  inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
 
17
  do_sample=True,
18
  temperature=temperature,
19
  top_p=0.9,
20
+ max_new_tokens=64,
21
+ no_repeat_ngram_size=2,
22
+ eos_token_id=self.tokenizer.eos_token_id
23
  )
24
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)