Delta0723 committed on
Commit
55d962c
·
verified ·
1 Parent(s): 91064f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from typing import Optional, List
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
6
  from peft import PeftModel
7
  import torch
8
  import os
@@ -39,12 +40,15 @@ try:
39
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
40
  tokenizer.pad_token = tokenizer.eos_token
41
 
42
- base_model = AutoModelForCausalLM.from_pretrained(
43
- BASE_MODEL,
44
- torch_dtype=torch.float16,
45
- device_map="auto"
46
- )
47
 
 
 
 
 
 
 
 
48
  model = PeftModel.from_pretrained(base_model, LORA_MODEL)
49
  model.eval()
50
 
 
3
  from pydantic import BaseModel
4
  from typing import Optional, List
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from transformers import BitsAndBytesConfig
7
  from peft import PeftModel
8
  import torch
9
  import os
 
40
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
41
  tokenizer.pad_token = tokenizer.eos_token
42
 
43
+ quant_config = BitsAndBytesConfig(load_in_4bit=True)
 
 
 
 
44
 
45
+ base_model = AutoModelForCausalLM.from_pretrained(
46
+ BASE_MODEL,
47
+ device_map="auto",
48
+ trust_remote_code=True,
49
+ offload_folder="offload",
50
+ quantization_config=quant_config
51
+
52
  model = PeftModel.from_pretrained(base_model, LORA_MODEL)
53
  model.eval()
54