nexusbert committed on
Commit
a9753f4
·
1 Parent(s): ab2012f

fixed 16Gi error

Browse files
Files changed (1) hide show
  1. model_manager.py +42 -10
model_manager.py CHANGED
@@ -22,27 +22,45 @@ def ensure_model_loaded():
22
  hf_token = os.getenv("HF_TOKEN")
23
 
24
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  if hf_token:
 
26
  style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
27
  model_id,
28
- token=hf_token,
29
- torch_dtype=torch.float32,
30
- device_map="auto",
31
- low_cpu_mem_usage=True
32
  )
33
  style_processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
34
  else:
35
  style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
36
  model_id,
37
- torch_dtype=torch.float32,
38
- device_map="auto",
39
- low_cpu_mem_usage=True
40
  )
41
  style_processor = AutoProcessor.from_pretrained(model_id)
42
 
43
- print(f"Loaded {model_id}")
44
  except Exception as e:
45
  print(f"Error loading model: {e}")
 
 
46
  raise
47
 
48
  def generate_chat_response(prompt: str, max_length: int = 512, temperature: float = 0.7, rag_context: Optional[str] = None, system_override: Optional[str] = None, images: Optional[List[str]] = None) -> str:
@@ -80,7 +98,14 @@ def generate_chat_response(prompt: str, max_length: int = 512, temperature: floa
80
  return_tensors="pt",
81
  )
82
 
83
- inputs = {k: v.to(style_model.device) for k, v in inputs.items()}
 
 
 
 
 
 
 
84
 
85
  temperature = max(0.1, min(1.5, temperature))
86
 
@@ -163,7 +188,14 @@ async def generate_chat_response_streaming(prompt: str, max_length: int = 512, t
163
  return_tensors="pt",
164
  )
165
 
166
- inputs = {k: v.to(style_model.device) for k, v in inputs.items()}
 
 
 
 
 
 
 
167
 
168
  temperature = max(0.1, min(1.5, temperature))
169
 
 
22
  hf_token = os.getenv("HF_TOKEN")
23
 
24
  try:
25
+ if torch.cuda.is_available():
26
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
27
+ max_memory = {0: "14GiB", "cpu": "2GiB"}
28
+ offload_folder = None
29
+ else:
30
+ dtype = torch.float16
31
+ max_memory = {"cpu": "14GiB"}
32
+ offload_folder = "/tmp/model_offload"
33
+ os.makedirs(offload_folder, exist_ok=True)
34
+
35
+ load_kwargs = {
36
+ "torch_dtype": dtype,
37
+ "device_map": "auto",
38
+ "low_cpu_mem_usage": True,
39
+ "max_memory": max_memory,
40
+ }
41
+
42
+ if offload_folder:
43
+ load_kwargs["offload_folder"] = offload_folder
44
+
45
  if hf_token:
46
+ load_kwargs["token"] = hf_token
47
  style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
48
  model_id,
49
+ **load_kwargs
 
 
 
50
  )
51
  style_processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
52
  else:
53
  style_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
54
  model_id,
55
+ **load_kwargs
 
 
56
  )
57
  style_processor = AutoProcessor.from_pretrained(model_id)
58
 
59
+ print(f"Loaded {model_id} with dtype={dtype}, device_map=auto")
60
  except Exception as e:
61
  print(f"Error loading model: {e}")
62
+ import traceback
63
+ traceback.print_exc()
64
  raise
65
 
66
  def generate_chat_response(prompt: str, max_length: int = 512, temperature: float = 0.7, rag_context: Optional[str] = None, system_override: Optional[str] = None, images: Optional[List[str]] = None) -> str:
 
98
  return_tensors="pt",
99
  )
100
 
101
+ if hasattr(style_model, 'device'):
102
+ device = style_model.device
103
+ elif hasattr(style_model, 'hf_device_map'):
104
+ device = next(iter(style_model.hf_device_map.values())) if style_model.hf_device_map else torch.device("cpu")
105
+ else:
106
+ device = torch.device("cpu")
107
+
108
+ inputs = {k: v.to(device) for k, v in inputs.items()}
109
 
110
  temperature = max(0.1, min(1.5, temperature))
111
 
 
188
  return_tensors="pt",
189
  )
190
 
191
+ if hasattr(style_model, 'device'):
192
+ device = style_model.device
193
+ elif hasattr(style_model, 'hf_device_map'):
194
+ device = next(iter(style_model.hf_device_map.values())) if style_model.hf_device_map else torch.device("cpu")
195
+ else:
196
+ device = torch.device("cpu")
197
+
198
+ inputs = {k: v.to(device) for k, v in inputs.items()}
199
 
200
  temperature = max(0.1, min(1.5, temperature))
201