Spaces:
Sleeping
Sleeping
quantumiracle
commited on
Commit
·
aa68c4e
1
Parent(s):
ccfcf8d
fix
Browse files- llava/llava_agent.py +9 -3
llava/llava_agent.py
CHANGED
|
@@ -22,10 +22,16 @@ class LLavaAgent:
|
|
| 22 |
device_map = {'model': torch.device(self.device).index, 'lm_head': torch.device(self.device).index}
|
| 23 |
else:
|
| 24 |
device_map = 'auto'
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
| 28 |
-
|
| 29 |
load_8bit=load_8bit, load_4bit=load_4bit)
|
| 30 |
self.model = model
|
| 31 |
self.image_processor = image_processor
|
|
|
|
| 22 |
device_map = {'model': torch.device(self.device).index, 'lm_head': torch.device(self.device).index}
|
| 23 |
else:
|
| 24 |
device_map = 'auto'
|
| 25 |
+
|
| 26 |
+
# Directly use HF repo if not local
|
| 27 |
+
if os.path.exists(model_path):
|
| 28 |
+
resolved_path = model_path
|
| 29 |
+
else:
|
| 30 |
+
resolved_path = model_path # treat as HF model ID
|
| 31 |
+
|
| 32 |
+
model_name = get_model_name_from_path(resolved_path)
|
| 33 |
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
| 34 |
+
resolved_path, None, model_name, device=self.device, device_map=device_map,
|
| 35 |
load_8bit=load_8bit, load_4bit=load_4bit)
|
| 36 |
self.model = model
|
| 37 |
self.image_processor = image_processor
|