Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,7 @@ model = None
|
|
| 11 |
tokenizer = None
|
| 12 |
generator = None
|
| 13 |
|
| 14 |
-
def load_model(model_name, eight_bit=0, device_map="auto"):
|
| 15 |
global model, tokenizer, generator
|
| 16 |
|
| 17 |
print("Loading "+model_name+"...")
|
|
@@ -24,7 +24,7 @@ def load_model(model_name, eight_bit=0, device_map="auto"):
|
|
| 24 |
print('gpu_count', gpu_count)
|
| 25 |
|
| 26 |
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
|
| 27 |
-
model = transformers.
|
| 28 |
model_name,
|
| 29 |
#device_map=device_map,
|
| 30 |
#device_map="auto",
|
|
|
|
| 11 |
tokenizer = None
|
| 12 |
generator = None
|
| 13 |
|
| 14 |
+
def load_model(model_name = "zl111/ChatDoctor", eight_bit=0, device_map="auto"):
|
| 15 |
global model, tokenizer, generator
|
| 16 |
|
| 17 |
print("Loading "+model_name+"...")
|
|
|
|
| 24 |
print('gpu_count', gpu_count)
|
| 25 |
|
| 26 |
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
|
| 27 |
+
model = transformers.LlamaForCausalLM.from_pretrained(
|
| 28 |
model_name,
|
| 29 |
#device_map=device_map,
|
| 30 |
#device_map="auto",
|