aeb56 committed
Commit e32298d Β· 1 Parent(s): 1443f5f

Add 8-bit quantization support and switch to L4x4 hardware for availability

Files changed (2)
  1. README.md +1 -1
  2. app.py +55 -17
README.md CHANGED
```diff
@@ -7,7 +7,7 @@ sdk: docker
 pinned: false
 license: apache-2.0
 app_port: 7860
-suggested_hardware: l40sx4
+suggested_hardware: l4x4
 ---
 
 # πŸ”— LoRA Model Merger
```
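For context on the hardware switch: `suggested_hardware` is only a hint, and an l4x4 grant provides 4 Γ— 24 GB β‰ˆ 96 GB of VRAM versus roughly 192 GB on the previous l40sx4 (4 Γ— 48 GB), which is what motivates the 8-bit option added in app.py below. A minimal sketch, not part of this commit, of a startup check that logs the GPUs actually granted (the function name and threshold are illustrative):

```python
# Illustrative sketch only: report the GPUs the Space actually received,
# since suggested_hardware is a hint rather than a guarantee.
import torch

def log_gpu_inventory(expected_total_gb: float = 96.0) -> float:
    """Print each visible GPU and warn if total VRAM falls short of the l4x4 target."""
    total_gb = 0.0
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        gib = props.total_memory / 1024**3
        total_gb += gib
        print(f"GPU {i}: {props.name}, {gib:.1f} GB")
    if total_gb < expected_total_gb:
        print(f"Warning: only {total_gb:.1f} GB VRAM visible; expected ~{expected_total_gb:.0f} GB on l4x4")
    return total_gb
```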
app.py CHANGED
```diff
@@ -61,7 +61,7 @@ class ModelMerger:
             logger.error(f"Login failed: {str(e)}")
             return f"❌ Login failed: {str(e)}"
 
-    def merge_models(self, hf_token, progress=gr.Progress()):
+    def merge_models(self, hf_token, use_8bit=False, progress=gr.Progress()):
         """Merge LoRA adapters with base model"""
         try:
             # Login to HF
@@ -79,16 +79,27 @@
             logger.info("Loading tokenizer...")
             self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, trust_remote_code=True)
 
-            # Configure memory allocation for multi-GPU setup (4xL40S = 4x48GB = 192GB)
-            # Reserve some memory for CUDA overhead and operations
+            # Configure memory allocation for multi-GPU setup
+            # Auto-detect GPU memory and adjust accordingly
             num_gpus = torch.cuda.device_count()
             max_memory = {}
+            total_vram = 0
+
             if num_gpus > 0:
-                # Allocate memory per GPU (leave ~2GB per GPU for overhead)
-                per_gpu_memory = "46GB"  # 48GB - 2GB overhead for L40S
+                # Calculate available memory per GPU
                 for i in range(num_gpus):
+                    gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
+                    total_vram += gpu_memory
+                    # Reserve 2-4GB per GPU for overhead
+                    per_gpu_memory = f"{int(gpu_memory - 3)}GB"
                     max_memory[i] = per_gpu_memory
+
+                logger.info(f"Detected {num_gpus} GPUs with total {total_vram:.1f}GB VRAM")
                 logger.info(f"Configured max_memory: {max_memory}")
+
+                # Warn if total VRAM is low
+                if total_vram < 90 and not use_8bit:
+                    logger.warning(f"Only {total_vram:.1f}GB VRAM available. The 48B model needs ~96GB in bfloat16. Consider enabling 8-bit quantization.")
             else:
                 # Fallback for CPU-only (will be slow)
                 max_memory = {"cpu": "64GB"}
@@ -97,28 +108,48 @@
             # Load base model with explicit multi-GPU configuration
             progress(0.25, desc="Loading base model (this may take several minutes)...")
             logger.info(f"Loading base model: {BASE_MODEL_NAME}")
-            logger.info(f"Using bfloat16 precision for memory efficiency")
+
+            if use_8bit:
+                logger.info(f"Using 8-bit quantization for memory efficiency (~50% memory reduction)")
+                precision_desc = "int8"
+            else:
+                logger.info(f"Using bfloat16 precision for memory efficiency")
+                precision_desc = "bfloat16"
 
             try:
+                load_kwargs = {
+                    "trust_remote_code": True,
+                    "low_cpu_mem_usage": True,
+                    "device_map": "auto",
+                    "max_memory": max_memory,
+                    "offload_folder": "/tmp/offload",
+                    "offload_state_dict": True,
+                }
+
+                if use_8bit:
+                    # Use 8-bit quantization for tighter memory constraints
+                    load_kwargs["load_in_8bit"] = True
+                else:
+                    # Use bfloat16 for best quality when memory allows
+                    load_kwargs["torch_dtype"] = torch.bfloat16
+
                 self.base_model = AutoModelForCausalLM.from_pretrained(
                     BASE_MODEL_NAME,
-                    torch_dtype=torch.bfloat16,
-                    device_map="auto",
-                    max_memory=max_memory,
-                    trust_remote_code=True,
-                    low_cpu_mem_usage=True,
-                    offload_folder="/tmp/offload",  # Fallback offload directory
-                    offload_state_dict=True,  # Offload state dict when loading
+                    **load_kwargs
                 )
-                logger.info("Base model loaded successfully")
+                logger.info(f"Base model loaded successfully in {precision_desc}")
 
                 # Log device map to see distribution
                 if hasattr(self.base_model, 'hf_device_map'):
                     logger.info(f"Model device map: {self.base_model.hf_device_map}")
 
             except torch.cuda.OutOfMemoryError as e:
-                logger.error("Out of memory error! Try with quantization or smaller batch size")
-                raise Exception(f"GPU Out of Memory: {str(e)}. The 48B model requires ~96GB VRAM in bfloat16. Ensure 4xL40S GPUs are available.")
+                logger.error("Out of memory error!")
+                error_msg = f"GPU Out of Memory: The 48B model requires ~96GB VRAM in bfloat16 or ~48GB in 8-bit.\n"
+                error_msg += f"You have {total_vram:.1f}GB VRAM available.\n"
+                if not use_8bit:
+                    error_msg += "\n**Try enabling 8-bit quantization** to reduce memory usage by ~50%."
+                raise Exception(error_msg)
 
             # Load LoRA configuration
             progress(0.50, desc="Loading LoRA adapters...")
@@ -318,12 +349,19 @@ with gr.Blocks(theme=gr.themes.Soft(), title="LoRA Model Merger") as demo:
                 info="Required for accessing private models or avoiding rate limits"
             )
 
+            with gr.Row():
+                use_8bit_checkbox = gr.Checkbox(
+                    label="Use 8-bit Quantization",
+                    value=False,
+                    info="Enable this if you have limited GPU memory (<96GB total). Reduces memory usage by ~50% with minimal quality loss."
+                )
+
             merge_button = gr.Button("πŸš€ Start Merge Process", variant="primary", size="lg")
             merge_output = gr.Markdown(label="Merge Status")
 
             merge_button.click(
                 fn=merger.merge_models,
-                inputs=[hf_token_merge],
+                inputs=[hf_token_merge, use_8bit_checkbox],
                 outputs=merge_output
             )
 
```