Update handler.py
handler.py — CHANGED (+5 −3)

@@ -8,11 +8,12 @@ token = os.getenv("HUGGINGFACE_HUB_TOKEN")
Old version (lines 8–18; removed lines marked "-"):

 8
 9   class EndpointHandler:
10       def __init__(self, path=""):
11 -         self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
12           base_model = AutoModelForCausalLM.from_pretrained(
13               "meta-llama/Llama-2-7b-hf",
14               torch_dtype=torch.float16,
15 -             device_map="auto"
16           )
17
18           lora_config = LoraConfig(
@@ -28,7 +29,8 @@ class EndpointHandler:

Old version (lines 28–34; removed line marked "-"):

28           adapter_path = hf_hub_download(
29               repo_id="vignesh0007/Anime-Gen-Llama-2-7B",
30               filename="adapter_model.safetensors",
31 -             repo_type="model"
32           )
33           lora_state = load_file(adapter_path)
34           self.model.load_state_dict(lora_state, strict=False)
New version (lines 8–19; added lines marked "+"):

 8
 9   class EndpointHandler:
10       def __init__(self, path=""):
11 +         self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=token)
12           base_model = AutoModelForCausalLM.from_pretrained(
13               "meta-llama/Llama-2-7b-hf",
14               torch_dtype=torch.float16,
15 +             device_map="auto",
16 +             token=token
17           )
18
19           lora_config = LoraConfig(
New version (lines 29–36; added lines marked "+"):

29           adapter_path = hf_hub_download(
30               repo_id="vignesh0007/Anime-Gen-Llama-2-7B",
31               filename="adapter_model.safetensors",
32 +             repo_type="model",
33 +             token=token
34           )
35           lora_state = load_file(adapter_path)
36           self.model.load_state_dict(lora_state, strict=False)