qwen to device
Browse files
- create_app.py +2 -0
create_app.py
CHANGED
|
@@ -21,6 +21,7 @@ def load_models():
|
|
| 21 |
global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
|
| 22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
if not MODELS_LOADED:
|
|
|
|
| 24 |
LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096', device='auto')
|
| 25 |
config = LongformerConfig.from_json_file("Longformer_checkpoint/config.json")
|
| 26 |
LONGFORMER_MODEL = CustomLongformerForSequenceClassification(config).from_pretrained('SFM2001/LongFormerScorer')
|
|
@@ -31,6 +32,7 @@ def load_models():
|
|
| 31 |
QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name, device='auto')
|
| 32 |
QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
|
| 33 |
QWEN_MODEL = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).half()
|
|
|
|
| 34 |
MODELS_LOADED = True
|
| 35 |
|
| 36 |
def create_app():
|
|
|
|
| 21 |
global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
|
| 22 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
if not MODELS_LOADED:
|
| 24 |
+
print("DEVICE", device)
|
| 25 |
LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096', device='auto')
|
| 26 |
config = LongformerConfig.from_json_file("Longformer_checkpoint/config.json")
|
| 27 |
LONGFORMER_MODEL = CustomLongformerForSequenceClassification(config).from_pretrained('SFM2001/LongFormerScorer')
|
|
|
|
| 32 |
QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name, device='auto')
|
| 33 |
QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
|
| 34 |
QWEN_MODEL = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).half()
|
| 35 |
+
QWEN_MODEL = QWEN_MODEL.to(device)
|
| 36 |
MODELS_LOADED = True
|
| 37 |
|
| 38 |
def create_app():
|