Spaces:
Sleeping
Sleeping
Commit
·
444ccdb
1
Parent(s):
3a7a5c6
update
Browse files- src/attribution/attntrace.py +1 -1
- src/models/Llama.py +7 -6
src/attribution/attntrace.py
CHANGED
|
@@ -41,7 +41,7 @@ class AttnTraceAttribution(Attribution):
|
|
| 41 |
if self.llm.model!=None:
|
| 42 |
self.model = self.llm.model
|
| 43 |
else:
|
| 44 |
-
self.model = self.llm._load_model_if_needed()
|
| 45 |
self.layers = range(len(self.model.model.layers))
|
| 46 |
model = self.model
|
| 47 |
tokenizer = self.tokenizer
|
|
|
|
| 41 |
if self.llm.model!=None:
|
| 42 |
self.model = self.llm.model
|
| 43 |
else:
|
| 44 |
+
self.model = self.llm._load_model_if_needed().to("cuda")
|
| 45 |
self.layers = range(len(self.model.model.layers))
|
| 46 |
model = self.model
|
| 47 |
tokenizer = self.tokenizer
|
src/models/Llama.py
CHANGED
|
@@ -24,17 +24,18 @@ class Llama(Model):
|
|
| 24 |
]
|
| 25 |
|
| 26 |
def _load_model_if_needed(self):
|
| 27 |
-
if self.
|
| 28 |
-
|
| 29 |
self.name,
|
| 30 |
torch_dtype=torch.bfloat16,
|
| 31 |
-
|
| 32 |
-
|
| 33 |
)
|
| 34 |
-
|
|
|
|
| 35 |
|
| 36 |
def query(self, msg, max_tokens=128000):
|
| 37 |
-
model = self._load_model_if_needed()
|
| 38 |
messages = self.messages
|
| 39 |
messages[1]["content"] = msg
|
| 40 |
|
|
|
|
| 24 |
]
|
| 25 |
|
| 26 |
def _load_model_if_needed(self):
|
| 27 |
+
if self._model is None:
|
| 28 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 29 |
self.name,
|
| 30 |
torch_dtype=torch.bfloat16,
|
| 31 |
+
token=self.hf_token,
|
| 32 |
+
device_map="auto", # or omit entirely to default to CPU
|
| 33 |
)
|
| 34 |
+
self._model = model
|
| 35 |
+
return self._model
|
| 36 |
|
| 37 |
def query(self, msg, max_tokens=128000):
|
| 38 |
+
model = self._load_model_if_needed().to("cuda")
|
| 39 |
messages = self.messages
|
| 40 |
messages[1]["content"] = msg
|
| 41 |
|