Ksjsjjdj commited on
Commit
a080efa
verified
1 Parent(s): e2e89c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -42,7 +42,7 @@ def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_d
42
  except Exception as e:
43
  return f"Error de autenticaci贸n: {str(e)}"
44
 
45
- device = "cuda" if torch.cuda.is_available() else "cpu"
46
  num_workers = multiprocessing.cpu_count()
47
 
48
  if not hasattr(torch, 'xla'):
@@ -166,7 +166,7 @@ def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_d
166
 
167
  progress(0.4, desc="Cargando Modelo...")
168
  try:
169
- original_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device)
170
  except Exception as e:
171
  return f"Error cargando modelo: {str(e)}"
172
 
@@ -216,7 +216,7 @@ def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_d
216
  trainer.save_model(output_dir)
217
 
218
  progress(0.9, desc="Fusionando...")
219
- ft = PeftModel.from_pretrained(original_model, output_dir, torch_dtype=torch.float32, is_trainable=False, device_map={"": device}).merge_and_unload()
220
 
221
  final_path = "/content/merged_model"
222
  ft.save_pretrained(final_path, safe_serialization=True)
 
42
  except Exception as e:
43
  return f"Error de autenticaci贸n: {str(e)}"
44
 
45
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
46
  num_workers = multiprocessing.cpu_count()
47
 
48
  if not hasattr(torch, 'xla'):
 
166
 
167
  progress(0.4, desc="Cargando Modelo...")
168
  try:
169
+ original_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
170
  except Exception as e:
171
  return f"Error cargando modelo: {str(e)}"
172
 
 
216
  trainer.save_model(output_dir)
217
 
218
  progress(0.9, desc="Fusionando...")
219
+ ft = PeftModel.from_pretrained(original_model, output_dir, torch_dtype=torch.float32, is_trainable=False).merge_and_unload()
220
 
221
  final_path = "/content/merged_model"
222
  ft.save_pretrained(final_path, safe_serialization=True)