Ignaciohhhhggfgjfrffd committed on
Commit
724deb0
·
verified ·
1 Parent(s): bab72e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -2
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
 
2
  os.system("pip install -U gradio")
3
  os.system("pip install -U bitsandbytes diffusers torchaudio torchvision torch transformers peft accelerate trl datasets")
4
- os.system("pip install spaces-0.1.0-py3-none-any.whl")
5
  os.system("pip install gradio_huggingfacehub_search packaging torchao llmcompressor")
6
 
7
  import io
@@ -81,6 +82,7 @@ from llmcompressor.modifiers.awq import AWQModifier
81
  logger = logging.getLogger(__name__)
82
  torch_dtype_auto = torch.float32
83
 
 
84
  def _sanitize_model_name_for_yaml(model_name):
85
  name = model_name.split('/')[-1] if '/' in model_name else model_name
86
  sanitized = re.sub(r'[^a-zA-Z0-9\-_\.]', '-', name)
@@ -173,11 +175,13 @@ _tox_pipe_singleton = None
173
 
174
  @spaces.GPU
175
  class DebiasingSFTTrainer(SFTTrainer):
 
176
  def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
177
  super().__init__(*args, **kwargs)
178
  self.reweighting_terms = [term.strip().lower() for term in reweighting_terms] if reweighting_terms else []
179
  self.reweighting_factor = reweighting_factor
180
 
 
181
  def compute_loss(self, model, inputs, return_outputs=False):
182
  loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
183
  if self.reweighting_terms and self.reweighting_factor > 1.0:
@@ -191,6 +195,7 @@ class DebiasingSFTTrainer(SFTTrainer):
191
 
192
  @spaces.GPU
193
  class DeduplicatedIterableDataset(IterableDataset):
 
194
  def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
195
  super().__init__(ex_iterable=iter([]))
196
  self.dataset = dataset
@@ -203,6 +208,7 @@ class DeduplicatedIterableDataset(IterableDataset):
203
  elif hasattr(dataset, 'info'):
204
  self._info = dataset.info
205
 
 
206
  def __iter__(self):
207
  if self.method == 'Exacta':
208
  return self._exact_iter()
@@ -211,6 +217,7 @@ class DeduplicatedIterableDataset(IterableDataset):
211
  else:
212
  return iter(self.dataset)
213
 
 
214
  def _exact_iter(self):
215
  seen_texts = set()
216
  for example in self.dataset:
@@ -222,6 +229,7 @@ class DeduplicatedIterableDataset(IterableDataset):
222
  else:
223
  yield example
224
 
 
225
  def _minhash_iter(self):
226
  lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm)
227
  for i, example in enumerate(self.dataset):
@@ -355,6 +363,7 @@ def _get_filter_functions(**kwargs):
355
  if kwargs.get('enable_language_filter'):
356
  allowed_langs = [lang.strip() for lang in kwargs.get('allowed_languages', 'en').split(',')]
357
  lang_threshold = kwargs.get('language_detection_threshold', 0.95)
 
358
  def lang_filter(ex):
359
  text = ex.get(kwargs['text_col'], "")
360
  if not text or not isinstance(text, str) or len(text.split()) < 5: return True
@@ -366,6 +375,7 @@ def _get_filter_functions(**kwargs):
366
  filters.append(lang_filter)
367
  if kwargs.get('enable_toxicity_filter'):
368
  tox_threshold = kwargs.get('toxicity_threshold', 0.8)
 
369
  def tox_filter(ex):
370
  global _tox_pipe_singleton
371
  if _tox_pipe_singleton is None:
@@ -386,6 +396,7 @@ def _get_filter_functions(**kwargs):
386
  filters.append(lambda ex: _apply_coherence_filter(ex, kwargs['text_col'], char_rep_thresh, ngram_rep_thresh, entropy_thresh))
387
  if any([kwargs.get('enable_readability_filter'), kwargs.get('enable_stopword_filter'), kwargs.get('enable_uniqueness_filter')]):
388
  stop_words = set(['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should', 'could', 'can', 'may', 'might', 'must', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'what', 'which', 'who', 'when', 'where', 'why', 'how'])
 
389
  def stats_filter(ex):
390
  text = ex.get(kwargs['text_col'], "")
391
  if not isinstance(text, str) or not text: return True
@@ -483,6 +494,7 @@ def _apply_cda(dataset, text_col, cda_config_str):
483
  except (json.JSONDecodeError, ValueError) as e:
484
  logger.error(f"Configuración de CDA inválida: {e}.")
485
  return dataset
 
486
  def cda_generator():
487
  for example in dataset:
488
  original_text = example.get(text_col, "")
@@ -520,6 +532,7 @@ def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id
520
  except Exception as e:
521
  logger.error(f"No se pudieron cargar los modelos de traducción: {e}")
522
  return dataset
 
523
  def bt_generator():
524
  for example in dataset:
525
  yield example
@@ -551,6 +564,7 @@ def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples,
551
  if not seed_examples:
552
  logger.warning("Dataset original vacío, no se pueden generar datos sintéticos.")
553
  return None
 
554
  def synthetic_generator():
555
  for i in range(num_samples):
556
  seed_example = random.choice(seed_examples)
@@ -570,6 +584,7 @@ def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples,
570
  continue
571
  return IterableDataset.from_generator(synthetic_generator)
572
 
 
573
  def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation):
574
  safe_steps = int(steps_per_epoch_estimate or 10000)
575
  safe_batch_size = int(batch_size or 1)
@@ -610,6 +625,7 @@ def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn
610
  yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación")
611
  return None
612
 
 
613
  def _create_training_args(output_dir, repo_id, **kwargs):
614
  neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
615
  optim_args_dict = {}
@@ -876,6 +892,7 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
876
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
877
  model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
878
  model.config.pad_token_id = tokenizer.pad_token_id
 
879
  def preprocess(examples):
880
  return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)
881
  train_dataset = train_dataset.map(preprocess, batched=True)
@@ -889,6 +906,7 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
889
  eval_dataset = update
890
  if eval_dataset: eval_dataset = eval_dataset.map(preprocess, batched=True)
891
  metric = hf_evaluate.load("accuracy")
 
892
  def compute_metrics(eval_pred):
893
  logits, labels = eval_pred
894
  predictions = np.argmax(logits, axis=-1)
@@ -916,6 +934,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
916
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True, add_prefix_space=True)
917
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
918
  model = _generic_model_loader(model_name, AutoModelForTokenClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
 
919
  def tokenize_and_align_labels(examples):
920
  tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
921
  labels = []
@@ -943,6 +962,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
943
  eval_dataset = update
944
  if eval_dataset: eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True)
945
  metric = hf_evaluate.load("seqeval")
 
946
  def compute_metrics(p):
947
  predictions, labels = p
948
  predictions = np.argmax(predictions, axis=2)
@@ -973,6 +993,7 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
973
  model = _generic_model_loader(model_name, AutoModelForQuestionAnswering, **kwargs)
974
  max_length = 384
975
  doc_stride = 128
 
976
  def prepare_train_features(examples):
977
  tokenized_examples = tokenizer(
978
  examples["question"],
@@ -1049,6 +1070,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
1049
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
1050
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
1051
  model = _generic_model_loader(model_name, AutoModelForSeq2SeqLM, **kwargs)
 
1052
  def preprocess_function(examples):
1053
  inputs = [ex[kwargs['text_col']] for ex in examples["translation"]]
1054
  targets = [ex[kwargs['label_col']] for ex in examples["translation"]]
@@ -1068,6 +1090,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
1068
  eval_dataset = update
1069
  if eval_dataset: eval_dataset = eval_dataset.map(preprocess_function, batched=True)
1070
  metric = hf_evaluate.load("sacrebleu")
 
1071
  def compute_metrics(eval_preds):
1072
  preds, labels = eval_preds
1073
  if isinstance(preds, tuple): preds = preds[0]
@@ -1125,6 +1148,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs, model_c
1125
  transforms.ToTensor(),
1126
  transforms.Normalize([0.5], [0.5]),
1127
  ])
 
1128
  def preprocess_train(examples):
1129
  images = [image.convert("RGB") for image in examples[image_col]]
1130
  examples["pixel_values"] = [image_transforms(image) for image in images]
@@ -1500,6 +1524,7 @@ def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in,
1500
  return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1501
  except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1502
 
 
1503
  def update_inference_ui(task_mode):
1504
  task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "")
1505
  is_text_gen = task_name == "text-generation"
@@ -1625,6 +1650,7 @@ def gradio_preview_data_wrapper(*args):
1625
  except Exception as e:
1626
  yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
1627
 
 
1628
  def toggle_training_mode_ui(is_scratch):
1629
  return (
1630
  gr.update(visible=not is_scratch),
@@ -1647,6 +1673,7 @@ def toggle_training_mode_ui(is_scratch):
1647
  gr.update(visible=is_scratch),
1648
  )
1649
 
 
1650
  def toggle_task_specific_ui(training_mode):
1651
  is_classification = "Classification" in training_mode
1652
  is_dpo = "DPO" in training_mode
@@ -1661,17 +1688,21 @@ def toggle_task_specific_ui(training_mode):
1661
  gr.update(visible=not is_diffusion)
1662
  )
1663
 
 
1664
def toggle_sft_format_ui(format_style):
    """Show the tool/reasoning options only when the SFT format is 'Razonamiento/Herramientas'."""
    return gr.update(visible=(format_style == "Razonamiento/Herramientas"))
1667
 
 
1668
def toggle_auto_modules_ui(is_auto):
    """Hide the manual module selector while automatic module selection is enabled."""
    show_manual_selector = not is_auto
    return gr.update(visible=show_manual_selector)
1670
 
 
1671
def toggle_dataset_creator_ui(choice):
    """Swap between the synthetic-data panel and the regular panel based on *choice*."""
    synth_selected = choice == "Sintético"
    return (
        gr.update(visible=synth_selected),
        gr.update(visible=not synth_selected),
    )
1674
 
 
1675
  def get_ao_username(token):
1676
  try:
1677
  api = HfApi(token=token)
@@ -1680,6 +1711,7 @@ def get_ao_username(token):
1680
  except Exception:
1681
  return "anonymous"
1682
 
 
1683
  def check_ao_model_exists(username, quantization_type, group_size, model_name, quantized_model_name, token):
1684
  try:
1685
  models = list_models(author=username, token=token)
@@ -1698,6 +1730,7 @@ def check_ao_model_exists(username, quantization_type, group_size, model_name, q
1698
  except Exception as e:
1699
  return f"Error checking model existence: {str(e)}"
1700
 
 
1701
  def create_ao_model_card(model_name, quantization_type, group_size, token):
1702
  try:
1703
  model_path = snapshot_download(repo_id=model_name, allow_patterns=["README.md"], repo_type="model", token=token)
@@ -1726,6 +1759,7 @@ tags:
1726
  yaml_header += "\n\n# 📄 Original Model Info\n\n" + original_readme
1727
  return yaml_header
1728
 
 
1729
  def quantize_ao_model(model_name, quantization_type, group_size=128, token=None, progress=gr.Progress()):
1730
  print(f"Quantizing model: {quantization_type}")
1731
  progress(0, desc="Preparing Quantization")
@@ -1753,6 +1787,7 @@ def quantize_ao_model(model_name, quantization_type, group_size=128, token=None,
1753
  progress(0.45, desc="Quantization completed")
1754
  return model
1755
 
 
1756
  def save_ao_model(model, model_name, quantization_type, group_size=128, quantized_model_name=None, public=True, token=None, progress=gr.Progress()):
1757
  username = get_ao_username(token)
1758
  progress(0.50, desc="Preparing to push")
@@ -1814,6 +1849,7 @@ def quantize_and_save_ao(model_name, quantization_type, group_size, quantized_mo
1814
  except Exception as e:
1815
  return f"<div class='error-box'><h3>❌ Error</h3><p>{str(e)}</p></div>"
1816
 
 
1817
  def get_awq_default_repo_name(model_id: str, scheme: str) -> str:
1818
  if not model_id or not scheme:
1819
  return ""
@@ -2337,4 +2373,4 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
2337
  """)
2338
 
2339
  if __name__ == "__main__":
2340
- demo.queue(max_size=1).launch()
 
1
  import os
2
+ os.system("pip install spaces-0.1.0-py3-none-any.whl")
3
  os.system("pip install -U gradio")
4
  os.system("pip install -U bitsandbytes diffusers torchaudio torchvision torch transformers peft accelerate trl datasets")
5
+ os.system("pip install spaces")
6
  os.system("pip install gradio_huggingfacehub_search packaging torchao llmcompressor")
7
 
8
  import io
 
82
  logger = logging.getLogger(__name__)
83
  torch_dtype_auto = torch.float32
84
 
85
+ @spaces.GPU
86
  def _sanitize_model_name_for_yaml(model_name):
87
  name = model_name.split('/')[-1] if '/' in model_name else model_name
88
  sanitized = re.sub(r'[^a-zA-Z0-9\-_\.]', '-', name)
 
175
 
176
  @spaces.GPU
177
  class DebiasingSFTTrainer(SFTTrainer):
178
+ @spaces.GPU
179
  def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
180
  super().__init__(*args, **kwargs)
181
  self.reweighting_terms = [term.strip().lower() for term in reweighting_terms] if reweighting_terms else []
182
  self.reweighting_factor = reweighting_factor
183
 
184
+ @spaces.GPU
185
  def compute_loss(self, model, inputs, return_outputs=False):
186
  loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
187
  if self.reweighting_terms and self.reweighting_factor > 1.0:
 
195
 
196
  @spaces.GPU
197
  class DeduplicatedIterableDataset(IterableDataset):
198
+ @spaces.GPU
199
  def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
200
  super().__init__(ex_iterable=iter([]))
201
  self.dataset = dataset
 
208
  elif hasattr(dataset, 'info'):
209
  self._info = dataset.info
210
 
211
+ @spaces.GPU
212
  def __iter__(self):
213
  if self.method == 'Exacta':
214
  return self._exact_iter()
 
217
  else:
218
  return iter(self.dataset)
219
 
220
+ @spaces.GPU
221
  def _exact_iter(self):
222
  seen_texts = set()
223
  for example in self.dataset:
 
229
  else:
230
  yield example
231
 
232
+ @spaces.GPU
233
  def _minhash_iter(self):
234
  lsh = MinHashLSH(threshold=self.threshold, num_perm=self.num_perm)
235
  for i, example in enumerate(self.dataset):
 
363
  if kwargs.get('enable_language_filter'):
364
  allowed_langs = [lang.strip() for lang in kwargs.get('allowed_languages', 'en').split(',')]
365
  lang_threshold = kwargs.get('language_detection_threshold', 0.95)
366
+ @spaces.GPU
367
  def lang_filter(ex):
368
  text = ex.get(kwargs['text_col'], "")
369
  if not text or not isinstance(text, str) or len(text.split()) < 5: return True
 
375
  filters.append(lang_filter)
376
  if kwargs.get('enable_toxicity_filter'):
377
  tox_threshold = kwargs.get('toxicity_threshold', 0.8)
378
+ @spaces.GPU
379
  def tox_filter(ex):
380
  global _tox_pipe_singleton
381
  if _tox_pipe_singleton is None:
 
396
  filters.append(lambda ex: _apply_coherence_filter(ex, kwargs['text_col'], char_rep_thresh, ngram_rep_thresh, entropy_thresh))
397
  if any([kwargs.get('enable_readability_filter'), kwargs.get('enable_stopword_filter'), kwargs.get('enable_uniqueness_filter')]):
398
  stop_words = set(['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should', 'could', 'can', 'may', 'might', 'must', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'what', 'which', 'who', 'when', 'where', 'why', 'how'])
399
+ @spaces.GPU
400
  def stats_filter(ex):
401
  text = ex.get(kwargs['text_col'], "")
402
  if not isinstance(text, str) or not text: return True
 
494
  except (json.JSONDecodeError, ValueError) as e:
495
  logger.error(f"Configuración de CDA inválida: {e}.")
496
  return dataset
497
+ @spaces.GPU
498
  def cda_generator():
499
  for example in dataset:
500
  original_text = example.get(text_col, "")
 
532
  except Exception as e:
533
  logger.error(f"No se pudieron cargar los modelos de traducción: {e}")
534
  return dataset
535
+ @spaces.GPU
536
  def bt_generator():
537
  for example in dataset:
538
  yield example
 
564
  if not seed_examples:
565
  logger.warning("Dataset original vacío, no se pueden generar datos sintéticos.")
566
  return None
567
+ @spaces.GPU
568
  def synthetic_generator():
569
  for i in range(num_samples):
570
  seed_example = random.choice(seed_examples)
 
584
  continue
585
  return IterableDataset.from_generator(synthetic_generator)
586
 
587
+ @spaces.GPU
588
  def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation):
589
  safe_steps = int(steps_per_epoch_estimate or 10000)
590
  safe_batch_size = int(batch_size or 1)
 
625
  yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación")
626
  return None
627
 
628
+ @spaces.GPU
629
  def _create_training_args(output_dir, repo_id, **kwargs):
630
  neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
631
  optim_args_dict = {}
 
892
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
893
  model = _generic_model_loader(model_name, AutoModelForSequenceClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
894
  model.config.pad_token_id = tokenizer.pad_token_id
895
+ @spaces.GPU
896
  def preprocess(examples):
897
  return tokenizer(examples[kwargs['text_col']], truncation=True, max_length=512)
898
  train_dataset = train_dataset.map(preprocess, batched=True)
 
906
  eval_dataset = update
907
  if eval_dataset: eval_dataset = eval_dataset.map(preprocess, batched=True)
908
  metric = hf_evaluate.load("accuracy")
909
+ @spaces.GPU
910
  def compute_metrics(eval_pred):
911
  logits, labels = eval_pred
912
  predictions = np.argmax(logits, axis=-1)
 
934
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True, add_prefix_space=True)
935
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
936
  model = _generic_model_loader(model_name, AutoModelForTokenClassification, num_labels=len(labels), label2id=label2id, id2label=id2label, **kwargs)
937
+ @spaces.GPU
938
  def tokenize_and_align_labels(examples):
939
  tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
940
  labels = []
 
962
  eval_dataset = update
963
  if eval_dataset: eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True)
964
  metric = hf_evaluate.load("seqeval")
965
+ @spaces.GPU
966
  def compute_metrics(p):
967
  predictions, labels = p
968
  predictions = np.argmax(predictions, axis=2)
 
993
  model = _generic_model_loader(model_name, AutoModelForQuestionAnswering, **kwargs)
994
  max_length = 384
995
  doc_stride = 128
996
+ @spaces.GPU
997
  def prepare_train_features(examples):
998
  tokenized_examples = tokenizer(
999
  examples["question"],
 
1070
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
1071
  yield update_logs_fn(f"Cargando modelo '{model_name}'...", "Configuración")
1072
  model = _generic_model_loader(model_name, AutoModelForSeq2SeqLM, **kwargs)
1073
+ @spaces.GPU
1074
  def preprocess_function(examples):
1075
  inputs = [ex[kwargs['text_col']] for ex in examples["translation"]]
1076
  targets = [ex[kwargs['label_col']] for ex in examples["translation"]]
 
1090
  eval_dataset = update
1091
  if eval_dataset: eval_dataset = eval_dataset.map(preprocess_function, batched=True)
1092
  metric = hf_evaluate.load("sacrebleu")
1093
+ @spaces.GPU
1094
  def compute_metrics(eval_preds):
1095
  preds, labels = eval_preds
1096
  if isinstance(preds, tuple): preds = preds[0]
 
1148
  transforms.ToTensor(),
1149
  transforms.Normalize([0.5], [0.5]),
1150
  ])
1151
+ @spaces.GPU
1152
  def preprocess_train(examples):
1153
  images = [image.convert("RGB") for image in examples[image_col]]
1154
  examples["pixel_values"] = [image_transforms(image) for image in images]
 
1524
  return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1525
  except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1526
 
1527
+ @spaces.GPU
1528
  def update_inference_ui(task_mode):
1529
  task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "")
1530
  is_text_gen = task_name == "text-generation"
 
1650
  except Exception as e:
1651
  yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
1652
 
1653
+ @spaces.GPU
1654
  def toggle_training_mode_ui(is_scratch):
1655
  return (
1656
  gr.update(visible=not is_scratch),
 
1673
  gr.update(visible=is_scratch),
1674
  )
1675
 
1676
+ @spaces.GPU
1677
  def toggle_task_specific_ui(training_mode):
1678
  is_classification = "Classification" in training_mode
1679
  is_dpo = "DPO" in training_mode
 
1688
  gr.update(visible=not is_diffusion)
1689
  )
1690
 
1691
# NOTE(review): dropped the @spaces.GPU decorator added in this commit.
# This is a pure Gradio UI callback with no GPU work; wrapping it in a
# ZeroGPU allocation adds per-click latency and burns GPU quota for nothing.
def toggle_sft_format_ui(format_style):
    """Show the tool/reasoning options only for the 'Razonamiento/Herramientas' format.

    Returns a single gr.update controlling that panel's visibility.
    """
    is_tool = format_style == "Razonamiento/Herramientas"
    return gr.update(visible=is_tool)
1695
 
1696
# NOTE(review): dropped the @spaces.GPU decorator added in this commit —
# a visibility toggle does no GPU work, and the decorator only adds a
# ZeroGPU scheduling round-trip to every UI interaction.
def toggle_auto_modules_ui(is_auto):
    """Hide the manual target-modules selector when auto mode is on.

    Returns a single gr.update controlling the selector's visibility.
    """
    return gr.update(visible=not is_auto)
1699
 
1700
# NOTE(review): dropped the @spaces.GPU decorator added in this commit.
# UI visibility callbacks must not request a GPU slot; @spaces.GPU belongs
# on the actual inference/training functions only.
def toggle_dataset_creator_ui(choice):
    """Switch between the synthetic and manual dataset-creator panels.

    Returns two gr.update objects: (synthetic panel, manual panel).
    """
    is_synth = choice == "Sintético"
    return gr.update(visible=is_synth), gr.update(visible=not is_synth)
1704
 
1705
+ @spaces.GPU
1706
  def get_ao_username(token):
1707
  try:
1708
  api = HfApi(token=token)
 
1711
  except Exception:
1712
  return "anonymous"
1713
 
1714
+ @spaces.GPU
1715
  def check_ao_model_exists(username, quantization_type, group_size, model_name, quantized_model_name, token):
1716
  try:
1717
  models = list_models(author=username, token=token)
 
1730
  except Exception as e:
1731
  return f"Error checking model existence: {str(e)}"
1732
 
1733
+ @spaces.GPU
1734
  def create_ao_model_card(model_name, quantization_type, group_size, token):
1735
  try:
1736
  model_path = snapshot_download(repo_id=model_name, allow_patterns=["README.md"], repo_type="model", token=token)
 
1759
  yaml_header += "\n\n# 📄 Original Model Info\n\n" + original_readme
1760
  return yaml_header
1761
 
1762
+ @spaces.GPU
1763
  def quantize_ao_model(model_name, quantization_type, group_size=128, token=None, progress=gr.Progress()):
1764
  print(f"Quantizing model: {quantization_type}")
1765
  progress(0, desc="Preparing Quantization")
 
1787
  progress(0.45, desc="Quantization completed")
1788
  return model
1789
 
1790
+ @spaces.GPU
1791
  def save_ao_model(model, model_name, quantization_type, group_size=128, quantized_model_name=None, public=True, token=None, progress=gr.Progress()):
1792
  username = get_ao_username(token)
1793
  progress(0.50, desc="Preparing to push")
 
1849
  except Exception as e:
1850
  return f"<div class='error-box'><h3>❌ Error</h3><p>{str(e)}</p></div>"
1851
 
1852
+ @spaces.GPU
1853
  def get_awq_default_repo_name(model_id: str, scheme: str) -> str:
1854
  if not model_id or not scheme:
1855
  return ""
 
2373
  """)
2374
 
2375
  if __name__ == "__main__":
2376
+ demo.queue().launch(debug=True, share=True)