Update app.py
Browse files
app.py
CHANGED
|
@@ -42,7 +42,7 @@ def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_d
|
|
| 42 |
except Exception as e:
|
| 43 |
return f"Error de autenticaci贸n: {str(e)}"
|
| 44 |
|
| 45 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 46 |
num_workers = multiprocessing.cpu_count()
|
| 47 |
|
| 48 |
if not hasattr(torch, 'xla'):
|
|
@@ -166,7 +166,7 @@ def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_d
|
|
| 166 |
|
| 167 |
progress(0.4, desc="Cargando Modelo...")
|
| 168 |
try:
|
| 169 |
-
original_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
| 170 |
except Exception as e:
|
| 171 |
return f"Error cargando modelo: {str(e)}"
|
| 172 |
|
|
@@ -216,7 +216,7 @@ def run_training(hf_token, model_name, new_repo_name, lora_r, lora_alpha, lora_d
|
|
| 216 |
trainer.save_model(output_dir)
|
| 217 |
|
| 218 |
progress(0.9, desc="Fusionando...")
|
| 219 |
-
ft = PeftModel.from_pretrained(original_model, output_dir, torch_dtype=torch.float32, is_trainable=False
|
| 220 |
|
| 221 |
final_path = "/content/merged_model"
|
| 222 |
ft.save_pretrained(final_path, safe_serialization=True)
|
|
|
|
| 42 |
except Exception as e:
|
| 43 |
return f"Error de autenticaci贸n: {str(e)}"
|
| 44 |
|
| 45 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 46 |
num_workers = multiprocessing.cpu_count()
|
| 47 |
|
| 48 |
if not hasattr(torch, 'xla'):
|
|
|
|
| 166 |
|
| 167 |
progress(0.4, desc="Cargando Modelo...")
|
| 168 |
try:
|
| 169 |
+
original_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
| 170 |
except Exception as e:
|
| 171 |
return f"Error cargando modelo: {str(e)}"
|
| 172 |
|
|
|
|
| 216 |
trainer.save_model(output_dir)
|
| 217 |
|
| 218 |
progress(0.9, desc="Fusionando...")
|
| 219 |
+
ft = PeftModel.from_pretrained(original_model, output_dir, torch_dtype=torch.float32, is_trainable=False).merge_and_unload()
|
| 220 |
|
| 221 |
final_path = "/content/merged_model"
|
| 222 |
ft.save_pretrained(final_path, safe_serialization=True)
|