SFM2001 committed on
Commit
b92108d
·
1 Parent(s): 78b4c1f

qwen to device

Browse files
Files changed (1) hide show
  1. create_app.py +2 -0
create_app.py CHANGED
@@ -21,6 +21,7 @@ def load_models():
21
  global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  if not MODELS_LOADED:
 
24
  LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096', device='auto')
25
  config = LongformerConfig.from_json_file("Longformer_checkpoint/config.json")
26
  LONGFORMER_MODEL = CustomLongformerForSequenceClassification(config).from_pretrained('SFM2001/LongFormerScorer')
@@ -31,6 +32,7 @@ def load_models():
31
  QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name, device='auto')
32
  QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
33
  QWEN_MODEL = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).half()
 
34
  MODELS_LOADED = True
35
 
36
  def create_app():
 
21
  global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
22
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  if not MODELS_LOADED:
24
+ print("DEVICE", device)
25
  LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096', device='auto')
26
  config = LongformerConfig.from_json_file("Longformer_checkpoint/config.json")
27
  LONGFORMER_MODEL = CustomLongformerForSequenceClassification(config).from_pretrained('SFM2001/LongFormerScorer')
 
32
  QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name, device='auto')
33
  QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
34
  QWEN_MODEL = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).half()
35
+ QWEN_MODEL = QWEN_MODEL.to(device)
36
  MODELS_LOADED = True
37
 
38
  def create_app():