xiao-fei committed on
Commit
807defb
·
1 Parent(s): 11fc986

fix minor problems in inference example

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -47,7 +47,7 @@ model = AutoModelForCausalLM.from_pretrained(
47
  pretrained_model_name_or_path="xiao-fei/Prot2Text-V2-11B-Instruct-hf",
48
  trust_remote_code=True,
49
  torch_dtype=torch.bfloat16,
50
- device_map="auto"
51
  )
52
 
53
  esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
@@ -81,16 +81,16 @@ tokenized_prompt = llama_tokenizer.apply_chat_template(
81
  return_dict=True
82
  )
83
  tokenized_sequence = esm_tokenizer(
84
- ex_seq,
85
  return_tensors="pt"
86
  )
87
 
88
  model.eval()
89
  generated = model.generate(
90
- inputs=tokenized_prompt["input_ids"].to(model.device()),
91
- attention_mask=tokenized_prompt["attention_mask"].to(model.device()),
92
- protein_input_ids=tokenized_sequence["input_ids"].to(model.device()),
93
- protein_attention_mask=tokenized_sequence["attention_mask"].to(model.device()),
94
  max_new_tokens=1024,
95
  eos_token_id=128009,
96
  pad_token_id=128002,
 
47
  pretrained_model_name_or_path="xiao-fei/Prot2Text-V2-11B-Instruct-hf",
48
  trust_remote_code=True,
49
  torch_dtype=torch.bfloat16,
50
+ device_map="cuda"
51
  )
52
 
53
  esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
 
81
  return_dict=True
82
  )
83
  tokenized_sequence = esm_tokenizer(
84
+ example_sequence,
85
  return_tensors="pt"
86
  )
87
 
88
  model.eval()
89
  generated = model.generate(
90
+ inputs=tokenized_prompt["input_ids"].to(model.device),
91
+ attention_mask=tokenized_prompt["attention_mask"].to(model.device),
92
+ protein_input_ids=tokenized_sequence["input_ids"].to(model.device),
93
+ protein_attention_mask=tokenized_sequence["attention_mask"].to(model.device),
94
  max_new_tokens=1024,
95
  eos_token_id=128009,
96
  pad_token_id=128002,