fixed 16Gi error
Browse files- model_manager.py +42 -10
model_manager.py
CHANGED
|
@@ -22,27 +22,45 @@ def ensure_model_loaded():
|
|
| 22 |
hf_token = os.getenv("HF_TOKEN")
|
| 23 |
|
| 24 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
if hf_token:
|
|
|
|
| 26 |
style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 27 |
model_id,
|
| 28 |
-
|
| 29 |
-
torch_dtype=torch.float32,
|
| 30 |
-
device_map="auto",
|
| 31 |
-
low_cpu_mem_usage=True
|
| 32 |
)
|
| 33 |
style_processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
|
| 34 |
else:
|
| 35 |
style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 36 |
model_id,
|
| 37 |
-
|
| 38 |
-
device_map="auto",
|
| 39 |
-
low_cpu_mem_usage=True
|
| 40 |
)
|
| 41 |
style_processor = AutoProcessor.from_pretrained(model_id)
|
| 42 |
|
| 43 |
-
print(f"Loaded {model_id}")
|
| 44 |
except Exception as e:
|
| 45 |
print(f"Error loading model: {e}")
|
|
|
|
|
|
|
| 46 |
raise
|
| 47 |
|
| 48 |
def generate_chat_response(prompt: str, max_length: int = 512, temperature: float = 0.7, rag_context: Optional[str] = None, system_override: Optional[str] = None, images: Optional[List[str]] = None) -> str:
|
|
@@ -80,7 +98,14 @@ def generate_chat_response(prompt: str, max_length: int = 512, temperature: floa
|
|
| 80 |
return_tensors="pt",
|
| 81 |
)
|
| 82 |
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
temperature = max(0.1, min(1.5, temperature))
|
| 86 |
|
|
@@ -163,7 +188,14 @@ async def generate_chat_response_streaming(prompt: str, max_length: int = 512, t
|
|
| 163 |
return_tensors="pt",
|
| 164 |
)
|
| 165 |
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
temperature = max(0.1, min(1.5, temperature))
|
| 169 |
|
|
|
|
hf_token = os.getenv("HF_TOKEN")

try:
    if torch.cuda.is_available():
        # Prefer bf16 when the GPU supports it (wider dynamic range than fp16).
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        # Cap GPU usage below the 16Gi limit; leave headroom for activations.
        max_memory = {0: "14GiB", "cpu": "2GiB"}
        offload_folder = None
    else:
        # NOTE(review): fp16 on CPU halves RAM (the point of this fix), but many
        # PyTorch CPU kernels lack Half support — confirm generation actually
        # runs on a CPU-only host, else fall back to float32.
        dtype = torch.float16
        max_memory = {"cpu": "14GiB"}
        # Spill layers that exceed max_memory to disk via accelerate.
        offload_folder = "/tmp/model_offload"
        os.makedirs(offload_folder, exist_ok=True)

    load_kwargs = {
        "torch_dtype": dtype,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
        "max_memory": max_memory,
    }
    if offload_folder:
        load_kwargs["offload_folder"] = offload_folder
    if hf_token:
        load_kwargs["token"] = hf_token

    # The original branched on hf_token with two byte-identical from_pretrained
    # calls (the token is already folded into load_kwargs above) — one call
    # covers both cases; only the processor still needs the explicit token.
    style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        **load_kwargs,
    )
    processor_kwargs = {"token": hf_token} if hf_token else {}
    style_processor = AutoProcessor.from_pretrained(model_id, **processor_kwargs)

    print(f"Loaded {model_id} with dtype={dtype}, device_map=auto")
except Exception as e:
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()
    raise
|
| 65 |
|
| 66 |
def generate_chat_response(prompt: str, max_length: int = 512, temperature: float = 0.7, rag_context: Optional[str] = None, system_override: Optional[str] = None, images: Optional[List[str]] = None) -> str:
|
|
|
|
| 98 |
return_tensors="pt",
|
| 99 |
)
|
| 100 |
|
| 101 |
# Move tokenized inputs onto the model's device.  With device_map="auto" the
# model may be sharded across devices; transformers' PreTrainedModel exposes a
# .device property (first parameter's device), so that branch normally wins.
if hasattr(style_model, 'device'):
    device = style_model.device
elif getattr(style_model, 'hf_device_map', None):
    # hf_device_map maps module names -> placements; accelerate may place
    # modules on "disk", which is NOT a valid torch device — picking it with
    # next(iter(...)) would crash .to().  Take the first non-disk placement.
    device = next(
        (d for d in style_model.hf_device_map.values() if d != "disk"),
        torch.device("cpu"),
    )
else:
    device = torch.device("cpu")

inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 109 |
|
| 110 |
temperature = max(0.1, min(1.5, temperature))
|
| 111 |
|
|
|
|
| 188 |
return_tensors="pt",
|
| 189 |
)
|
| 190 |
|
| 191 |
# Move tokenized inputs onto the model's device (streaming path).  With
# device_map="auto" the model may be sharded; transformers' PreTrainedModel
# exposes a .device property, so that branch normally wins.
if hasattr(style_model, 'device'):
    device = style_model.device
elif getattr(style_model, 'hf_device_map', None):
    # hf_device_map values may include "disk" (accelerate offload), which is
    # not a valid torch device and would crash .to() — skip such entries.
    device = next(
        (d for d in style_model.hf_device_map.values() if d != "disk"),
        torch.device("cpu"),
    )
else:
    device = torch.device("cpu")

inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 199 |
|
| 200 |
temperature = max(0.1, min(1.5, temperature))
|
| 201 |
|