Ignaciohhhhggfgjfrffd commited on
Commit
3bb1f41
·
verified ·
1 Parent(s): ffde733

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +390 -14
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
2
  os.system("pip install -U gradio")
3
  os.system("pip install -U bitsandbytes diffusers torchaudio torchvision torch transformers peft accelerate trl datasets")
4
- os.system("pip install spaces-0.1.0-py3-none-any.whl")
 
5
 
6
  import io
7
  import json
@@ -16,6 +17,7 @@ import importlib
16
  import random
17
  import re
18
  import ast
 
19
  from itertools import islice
20
  from pathlib import Path
21
  from collections import defaultdict
@@ -38,11 +40,11 @@ import textstat
38
  from datasketch import MinHash, MinHashLSH
39
  import gradio as gr
40
  from datasets import load_dataset, IterableDataset, Dataset as HFDataset, DatasetDict, interleave_datasets, Audio
41
- from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi, hf_hub_download, list_repo_files
42
  from transformers import (
43
  AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
44
  AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
45
- AutoModelForImageClassification,
46
  AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
47
  DataCollatorForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSpeechSeq2Seq,
48
  AutoProcessor, DataCollatorWithPadding, pipeline,
@@ -62,6 +64,19 @@ from diffusers import (
62
  get_scheduler as get_diffusers_scheduler, StableDiffusionPipeline as StableDiffusionText2ImagePipeline,
63
  StableDiffusionImg2ImgPipeline as StableDiffusionImage2ImagePipeline
64
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  logger = logging.getLogger(__name__)
67
  torch_dtype_auto = torch.float32
@@ -135,7 +150,27 @@ Este dataset fue creado utilizando la herramienta [AutoTrain-Advanced](https://h
135
  - **Modelo de Generación (si aplica):** `{generation_model}`
136
  - **Fecha de Creación:** {date}
137
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  _tox_pipe_singleton = None
 
139
  @spaces.GPU
140
  class DebiasingSFTTrainer(SFTTrainer):
141
  def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
@@ -153,6 +188,7 @@ class DebiasingSFTTrainer(SFTTrainer):
153
  loss *= self.reweighting_factor
154
  break
155
  return (loss, outputs) if return_outputs else loss
 
156
  @spaces.GPU
157
  class DeduplicatedIterableDataset(IterableDataset):
158
  def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
@@ -210,6 +246,7 @@ def hf_login(token):
210
  return f"✅ Conectado como: {user['name']}"
211
  except Exception as e:
212
  return f"❌ Error en la conexión: {e}"
 
213
  @spaces.GPU
214
  def _clean_text(example, text_col, **kwargs):
215
  text = example.get(text_col, "")
@@ -227,6 +264,7 @@ def _clean_text(example, text_col, **kwargs):
227
  text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '<IP_ADDRESS>', text)
228
  example[text_col] = text
229
  return example
 
230
  @spaces.GPU
231
  def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords):
232
  text = example.get(text_col, "")
@@ -240,6 +278,7 @@ def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, e
240
  if not word_counts or (max(word_counts.values()) / len(words)) > rep_threshold: return False
241
  lower_text = text.lower()
242
  return not any(keyword in lower_text for keyword in exclude_keywords)
 
243
  @spaces.GPU
244
  def _apply_coherence_filter(example, text_col, char_rep_threshold, ngram_rep_threshold, entropy_threshold):
245
  text = example.get(text_col, "")
@@ -306,6 +345,7 @@ def _apply_coherence_filter(example, text_col, char_rep_threshold, ngram_rep_thr
306
  if non_latin_chars > 2 and latin_chars > 10:
307
  return False
308
  return True
 
309
  @spaces.GPU
310
  def _get_filter_functions(**kwargs):
311
  filters = []
@@ -366,6 +406,7 @@ def _get_filter_functions(**kwargs):
366
  return True
367
  filters.append(stats_filter)
368
  return filters
 
369
  @spaces.GPU
370
  def _load_hf_streaming(ids, split="train", probabilities=None):
371
  streams = []
@@ -395,6 +436,7 @@ def _load_hf_streaming(ids, split="train", probabilities=None):
395
  logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
396
  probabilities = None
397
  return interleave_datasets(streams, probabilities=probabilities)
 
398
  @spaces.GPU
399
  def _load_uploaded_stream(files):
400
  all_rows = []
@@ -416,6 +458,7 @@ def _load_uploaded_stream(files):
416
  val_size = max(1, int(len(all_rows) * 0.01))
417
  random.shuffle(all_rows)
418
  return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []}
 
419
  @spaces.GPU
420
  def _guess_columns(sample):
421
  text_col, image_col, audio_col, label_col = "text", "image", "audio", "label"
@@ -432,6 +475,7 @@ def _guess_columns(sample):
432
  if "label" in keys: label_col = keys["label"]
433
  elif "labels" in keys: label_col = keys["labels"]
434
  return text_col, image_col, audio_col, label_col
 
435
  @spaces.GPU
436
  def _apply_cda(dataset, text_col, cda_config_str):
437
  try:
@@ -464,6 +508,7 @@ def _apply_cda(dataset, text_col, cda_config_str):
464
  next_texts.add(new_text)
465
  current_texts.update(next_texts)
466
  return IterableDataset.from_generator(cda_generator)
 
467
  @spaces.GPU
468
  def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id):
469
  if not ratio or ratio <= 0:
@@ -491,6 +536,7 @@ def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id
491
  except Exception as e:
492
  logger.warning(f"Error en retrotraducción: {e}")
493
  return IterableDataset.from_generator(bt_generator)
 
494
  @spaces.GPU
495
  def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template):
496
  if not num_samples or num_samples <= 0:
@@ -523,6 +569,7 @@ def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples,
523
  logger.warning(f"Error generando una muestra sintética: {e}")
524
  continue
525
  return IterableDataset.from_generator(synthetic_generator)
 
526
  def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation):
527
  safe_steps = int(steps_per_epoch_estimate or 10000)
528
  safe_batch_size = int(batch_size or 1)
@@ -540,6 +587,7 @@ def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, b
540
  layers = max(8, min(32, 8 + int(log_size * 1.5)))
541
  kv_heads = heads if is_gpt2_like else (max(1, heads // 4))
542
  return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads
 
543
  @spaces.GPU
544
  def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn):
545
  if eval_ds_id:
@@ -561,6 +609,7 @@ def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn
561
  return None
562
  yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación")
563
  return None
 
564
  def _create_training_args(output_dir, repo_id, **kwargs):
565
  neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
566
  optim_args_dict = {}
@@ -610,6 +659,7 @@ def _create_training_args(output_dir, repo_id, **kwargs):
610
  else:
611
  raise ValueError("Para datasets en streaming se requiere un valor positivo para 'Máximos Pasos de Entrenamiento'.")
612
  return TrainingArguments(**args_dict)
 
613
  @spaces.GPU
614
  def _generic_model_loader(model_name_or_path, model_class, **kwargs):
615
  config_kwargs = {"trust_remote_code": True}
@@ -627,6 +677,7 @@ def _generic_model_loader(model_name_or_path, model_class, **kwargs):
627
  model_kwargs.update({"num_labels": kwargs['num_labels'], "ignore_mismatched_sizes": True})
628
  model = model_class.from_pretrained(model_name_or_path, **model_kwargs)
629
  return model
 
630
  @spaces.GPU
631
  def _find_all_linear_names(model):
632
  cls = torch.nn.Linear
@@ -639,6 +690,7 @@ def _find_all_linear_names(model):
639
  lora_module_names.remove('lm_head')
640
  common_targets = {'q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'}
641
  return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
 
642
  @spaces.GPU
643
  def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
644
  if kwargs.get('sft_format_style') == "Conversacional":
@@ -672,9 +724,11 @@ def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
672
  return "\n".join([m['content'] for m in messages])
673
  return ""
674
  return example.get(text_col, "")
 
675
  @spaces.GPU
676
  def _dpo_formatting_func(example, **kwargs):
677
  return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")}
 
678
  @spaces.GPU
679
  def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
680
  model.eval()
@@ -699,6 +753,7 @@ def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
699
  break
700
  ppl = torch.exp(torch.stack(nlls).mean())
701
  return ppl.item()
 
702
  @spaces.GPU
703
  def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
704
  adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()]
@@ -730,6 +785,7 @@ def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combinati
730
  tokenizer.save_pretrained(temp_dir)
731
  yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}."
732
  return temp_dir
 
733
  @spaces.GPU
734
  def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
735
  yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
@@ -750,6 +806,7 @@ def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_c
750
  yield update_logs_fn("Subiendo al Hub...", "Subiendo")
751
  upload_folder(folder_path=output_dir, repo_id=repo_id, commit_message="Fin de entrenamiento")
752
  return output_dir, final_metrics
 
753
  @spaces.GPU
754
  def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
755
  output_dir = tempfile.mkdtemp()
@@ -803,6 +860,7 @@ def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card
803
  return final_model_path, final_metrics
804
  except Exception as e:
805
  raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}")
 
806
  @spaces.GPU
807
  def train_sequence_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
808
  output_dir = tempfile.mkdtemp()
@@ -845,6 +903,7 @@ def train_sequence_classification(model_name, train_dataset, repo_id, update_log
845
  return final_model_path, final_metrics
846
  except Exception as e:
847
  raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}")
 
848
  @spaces.GPU
849
  def train_token_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
850
  output_dir = tempfile.mkdtemp()
@@ -902,6 +961,7 @@ def train_token_classification(model_name, train_dataset, repo_id, update_logs_f
902
  return final_model_path, final_metrics
903
  except Exception as e:
904
  raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}")
 
905
  @spaces.GPU
906
  def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
907
  output_dir = tempfile.mkdtemp()
@@ -979,6 +1039,7 @@ def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn,
979
  return final_model_path, final_metrics
980
  except Exception as e:
981
  raise Exception(f"Error en Question Answering: {e}\n{traceback.format_exc()}")
 
982
  @spaces.GPU
983
  def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
984
  output_dir = tempfile.mkdtemp()
@@ -1030,6 +1091,7 @@ def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card
1030
  return final_model_path, final_metrics
1031
  except Exception as e:
1032
  raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}")
 
1033
  @spaces.GPU
1034
  def train_text_to_image(model_name, train_dataset, repo_id, update_logs, model_card_content, **kwargs):
1035
  output_dir = tempfile.mkdtemp()
@@ -1181,6 +1243,7 @@ def train_text_to_image(model_name, train_dataset, repo_id, update_logs, model_c
1181
  except Exception as e:
1182
  yield update_logs(f"❌ Error en entrenamiento Text-to-Image: {str(e)}", "Error")
1183
  raise Exception(f"Error en Text-to-Image: {e}\n{traceback.format_exc()}")
 
1184
  @spaces.GPU
1185
  def _get_data_processing_pipeline(**kwargs):
1186
  hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
@@ -1243,6 +1306,7 @@ def _get_data_processing_pipeline(**kwargs):
1243
  num_perm=int(kwargs.get('minhash_num_perm', 128))
1244
  )
1245
  return train_dataset, kwargs
 
1246
  @spaces.GPU
1247
  def _train_and_upload(progress=gr.Progress(), **kwargs):
1248
  logs, repo_link, final_model_path, final_metrics = "", "", None, {}
@@ -1411,6 +1475,7 @@ def _train_and_upload(progress=gr.Progress(), **kwargs):
1411
  gr.update(value="Iniciar Entrenamiento", interactive=True),
1412
  gr.update(visible=False)
1413
  )
 
1414
  @spaces.GPU
1415
  def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in, temperature, top_p, max_new_tokens):
1416
  if not model_id: return "Por favor, introduce un ID de modelo del Hub.", model_id, gr.update(), gr.update(), gr.update(), gr.update()
@@ -1434,6 +1499,7 @@ def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in,
1434
  result = pipe(input_data)
1435
  return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1436
  except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
 
1437
  def update_inference_ui(task_mode):
1438
  task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "")
1439
  is_text_gen = task_name == "text-generation"
@@ -1449,6 +1515,7 @@ def update_inference_ui(task_mode):
1449
  gr.update(visible=show_audio),
1450
  gr.update(visible=is_text_gen)
1451
  )
 
1452
  @spaces.GPU
1453
  def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, synth_prompt, synth_num_samples, file_uploads, progress=gr.Progress()):
1454
  if not hf_token:
@@ -1510,10 +1577,12 @@ def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, s
1510
  return f"✅ Dataset creado y subido exitosamente a {repo_id}", f"### ✅ [Dataset Disponible: Visita el Repositorio]({dataset_link})"
1511
  except Exception as e:
1512
  return f"❌ Error fatal durante la creación del dataset: {e}\n{traceback.format_exc()}", ""
 
1513
  @spaces.GPU
1514
  def gradio_train_wrapper(*args):
1515
  kwargs = dict(zip(all_input_components_dict.keys(), args))
1516
  yield from _train_and_upload(**kwargs)
 
1517
  @spaces.GPU
1518
  def gradio_preview_data_wrapper(*args):
1519
  kwargs = dict(zip(all_input_components_dict.keys(), args))
@@ -1555,6 +1624,7 @@ def gradio_preview_data_wrapper(*args):
1555
  yield preview_text
1556
  except Exception as e:
1557
  yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
 
1558
  def toggle_training_mode_ui(is_scratch):
1559
  return (
1560
  gr.update(visible=not is_scratch),
@@ -1576,6 +1646,7 @@ def toggle_training_mode_ui(is_scratch):
1576
  gr.update(visible=is_scratch),
1577
  gr.update(visible=is_scratch),
1578
  )
 
1579
  def toggle_task_specific_ui(training_mode):
1580
  is_classification = "Classification" in training_mode
1581
  is_dpo = "DPO" in training_mode
@@ -1589,18 +1660,284 @@ def toggle_task_specific_ui(training_mode):
1589
  gr.update(visible=is_diffusion),
1590
  gr.update(visible=not is_diffusion)
1591
  )
 
1592
  def toggle_sft_format_ui(format_style):
1593
  is_tool = format_style == "Razonamiento/Herramientas"
1594
  return gr.update(visible=is_tool)
 
1595
  def toggle_auto_modules_ui(is_auto):
1596
  return gr.update(visible=not is_auto)
 
1597
  def toggle_dataset_creator_ui(choice):
1598
  is_synth = choice == "Sintético"
1599
  return gr.update(visible=is_synth), gr.update(visible=not is_synth)
1600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1601
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1602
- gr.Markdown("# 🚀 AutoTrain-Advanced: Tu Plataforma de Entrenamiento de Modelos")
1603
- gr.Markdown("### Una interfaz completa para fine-tuning y PEFT (LoRA).")
1604
 
1605
  with gr.Tab("1. Autenticación"):
1606
  gr.Markdown("#### Conecta tu cuenta de Hugging Face para guardar y cargar modelos.")
@@ -1916,7 +2253,50 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1916
  inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_temperature, inf_top_p, inf_max_new_tokens],
1917
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
1918
  )
1919
- with gr.Tab("5. Explicación del Código y Mecanismos Avanzados"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1920
  gr.Markdown("""
1921
  ### 🧠 Explicación del Código y Mecanismos Avanzados
1922
  """)
@@ -1944,14 +2324,10 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1944
  * Task-Specific Heads: Supports **Sequence Classification**, **Token Classification (NER)**, and **Question Answering** by loading appropriate model heads (`AutoModelFor...`).
1945
  * Seq2Seq: For translation/summarization tasks, using `Seq2SeqTrainer`.
1946
  """)
1947
- gr.Markdown("#### 4. MODEL INITIALIZATION & ADVANCED TECHNIQUES")
1948
  gr.Markdown("""
1949
- * Model From Scratch: Allows initializing a model (e.g., Llama, Mistral) from a config rather than a pre-trained checkpoint, with optional auto-configuration based on expected training scale.
1950
- * Manual Model Configuration: When training from scratch, users can manually specify low-level configuration parameters (e.g., `vocab_size`, `hidden_size`, `num_hidden_layers`) instead of relying on the automatic scaling based on training steps.
1951
- * Multi-Adapter Merging: Advanced feature to combine multiple existing LoRA adapters into a single, new adapter using weighted averaging (`slerp`, `linear`, etc.).
1952
- * DoRA (Weight-Decomposed Low-Rank Adaptation): A more advanced version of LoRA that can lead to better performance.
1953
- * RSLora (Rank-Stabilized LoRA): A variant of LoRA that adjusts the learning rate based on the rank, improving stability.
1954
- * NEFTune: Adds noise to the embedding layer during training, which can improve the performance of the fine-tuned model.
1955
  """)
1956
  gr.Markdown("#### 5. OUTPUT & DEPLOYMENT")
1957
  gr.Markdown("""
@@ -1961,4 +2337,4 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1961
  """)
1962
 
1963
  if __name__ == "__main__":
1964
- demo.queue().launch(debug=True, share=True)
 
1
  import os
2
  os.system("pip install -U gradio")
3
  os.system("pip install -U bitsandbytes diffusers torchaudio torchvision torch transformers peft accelerate trl datasets")
4
+ os.system("pip install spaces")
5
+ os.system("pip install gradio_huggingfacehub_search packaging torchao llmcompressor")
6
 
7
  import io
8
  import json
 
17
  import random
18
  import re
19
  import ast
20
+ import shutil
21
  from itertools import islice
22
  from pathlib import Path
23
  from collections import defaultdict
 
40
  from datasketch import MinHash, MinHashLSH
41
  import gradio as gr
42
  from datasets import load_dataset, IterableDataset, Dataset as HFDataset, DatasetDict, interleave_datasets, Audio
43
+ from huggingface_hub import login, whoami, create_repo, upload_folder, HfApi, hf_hub_download, list_repo_files, snapshot_download, list_models
44
  from transformers import (
45
  AutoModelForCausalLM, AutoTokenizer, AutoConfig, TrainingArguments, Trainer,
46
  AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer,
47
+ AutoModelForImageClassification, AutoModel, TorchAoConfig,
48
  AutoImageProcessor, AutoModelForAudioClassification, AutoFeatureExtractor, AutoModelForTokenClassification,
49
  DataCollatorForTokenClassification, AutoModelForQuestionAnswering, AutoModelForSpeechSeq2Seq,
50
  AutoProcessor, DataCollatorWithPadding, pipeline,
 
64
  get_scheduler as get_diffusers_scheduler, StableDiffusionPipeline as StableDiffusionText2ImagePipeline,
65
  StableDiffusionImg2ImgPipeline as StableDiffusionImage2ImagePipeline
66
  )
67
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
68
+ from packaging import version
69
+ from torchao.quantization import (
70
+ Int4WeightOnlyConfig,
71
+ Int8WeightOnlyConfig,
72
+ Int8DynamicActivationInt8WeightConfig,
73
+ Float8WeightOnlyConfig,
74
+ Float8DynamicActivationFloat8WeightConfig,
75
+ GemliteUIntXWeightOnlyConfig,
76
+ )
77
+ from torchao.dtypes import Int4CPULayout
78
+ from llmcompressor import oneshot
79
+ from llmcompressor.modifiers.awq import AWQModifier
80
 
81
  logger = logging.getLogger(__name__)
82
  torch_dtype_auto = torch.float32
 
150
  - **Modelo de Generación (si aplica):** `{generation_model}`
151
  - **Fecha de Creación:** {date}
152
  """
153
+
154
+ MAP_QUANT_TYPE_TO_NAME = {
155
+ "Int4WeightOnly": "int4wo",
156
+ "GemliteUIntXWeightOnly": "intxwo-gemlite",
157
+ "Int8WeightOnly": "int8wo",
158
+ "Int8DynamicActivationInt8Weight": "int8da8w8",
159
+ "Float8WeightOnly": "float8wo",
160
+ "Float8DynamicActivationFloat8Weight": "float8da8w8",
161
+ "autoquant": "autoquant",
162
+ }
163
+ MAP_QUANT_TYPE_TO_CONFIG = {
164
+ "Int4WeightOnly": Int4WeightOnlyConfig,
165
+ "GemliteUIntXWeightOnly": GemliteUIntXWeightOnlyConfig,
166
+ "Int8WeightOnly": Int8WeightOnlyConfig,
167
+ "Int8DynamicActivationInt8Weight": Int8DynamicActivationInt8WeightConfig,
168
+ "Float8WeightOnly": Float8WeightOnlyConfig,
169
+ "Float8DynamicActivationFloat8Weight": Float8DynamicActivationFloat8WeightConfig,
170
+ }
171
+
172
  _tox_pipe_singleton = None
173
+
174
  @spaces.GPU
175
  class DebiasingSFTTrainer(SFTTrainer):
176
  def __init__(self, *args, reweighting_terms=None, reweighting_factor=1.0, **kwargs):
 
188
  loss *= self.reweighting_factor
189
  break
190
  return (loss, outputs) if return_outputs else loss
191
+
192
  @spaces.GPU
193
  class DeduplicatedIterableDataset(IterableDataset):
194
  def __init__(self, dataset, text_col, method, threshold=0.85, num_perm=128):
 
246
  return f"✅ Conectado como: {user['name']}"
247
  except Exception as e:
248
  return f"❌ Error en la conexión: {e}"
249
+
250
  @spaces.GPU
251
  def _clean_text(example, text_col, **kwargs):
252
  text = example.get(text_col, "")
 
264
  text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '<IP_ADDRESS>', text)
265
  example[text_col] = text
266
  return example
267
+
268
  @spaces.GPU
269
  def _apply_quality_filters(example, text_col, min_len, max_len, rep_threshold, exclude_keywords):
270
  text = example.get(text_col, "")
 
278
  if not word_counts or (max(word_counts.values()) / len(words)) > rep_threshold: return False
279
  lower_text = text.lower()
280
  return not any(keyword in lower_text for keyword in exclude_keywords)
281
+
282
  @spaces.GPU
283
  def _apply_coherence_filter(example, text_col, char_rep_threshold, ngram_rep_threshold, entropy_threshold):
284
  text = example.get(text_col, "")
 
345
  if non_latin_chars > 2 and latin_chars > 10:
346
  return False
347
  return True
348
+
349
  @spaces.GPU
350
  def _get_filter_functions(**kwargs):
351
  filters = []
 
406
  return True
407
  filters.append(stats_filter)
408
  return filters
409
+
410
  @spaces.GPU
411
  def _load_hf_streaming(ids, split="train", probabilities=None):
412
  streams = []
 
436
  logger.warning(f"Number of probabilities ({len(probabilities)}) does not match number of valid datasets ({len(streams)}). Ignoring weights.")
437
  probabilities = None
438
  return interleave_datasets(streams, probabilities=probabilities)
439
+
440
  @spaces.GPU
441
  def _load_uploaded_stream(files):
442
  all_rows = []
 
458
  val_size = max(1, int(len(all_rows) * 0.01))
459
  random.shuffle(all_rows)
460
  return {"train": all_rows[:-val_size] if val_size > 0 else all_rows, "validation": all_rows[-val_size:] if val_size > 0 else []}
461
+
462
  @spaces.GPU
463
  def _guess_columns(sample):
464
  text_col, image_col, audio_col, label_col = "text", "image", "audio", "label"
 
475
  if "label" in keys: label_col = keys["label"]
476
  elif "labels" in keys: label_col = keys["labels"]
477
  return text_col, image_col, audio_col, label_col
478
+
479
  @spaces.GPU
480
  def _apply_cda(dataset, text_col, cda_config_str):
481
  try:
 
508
  next_texts.add(new_text)
509
  current_texts.update(next_texts)
510
  return IterableDataset.from_generator(cda_generator)
511
+
512
  @spaces.GPU
513
  def _apply_back_translation(dataset, text_col, ratio, model_id, reverse_model_id):
514
  if not ratio or ratio <= 0:
 
536
  except Exception as e:
537
  logger.warning(f"Error en retrotraducción: {e}")
538
  return IterableDataset.from_generator(bt_generator)
539
+
540
  @spaces.GPU
541
  def _generate_synthetic_data(original_dataset, text_col, model_id, num_samples, prompt_template):
542
  if not num_samples or num_samples <= 0:
 
569
  logger.warning(f"Error generando una muestra sintética: {e}")
570
  continue
571
  return IterableDataset.from_generator(synthetic_generator)
572
+
573
  def _calculate_auto_config(block_size, is_gpt2_like, steps_per_epoch_estimate, batch_size, gradient_accumulation):
574
  safe_steps = int(steps_per_epoch_estimate or 10000)
575
  safe_batch_size = int(batch_size or 1)
 
587
  layers = max(8, min(32, 8 + int(log_size * 1.5)))
588
  kv_heads = heads if is_gpt2_like else (max(1, heads // 4))
589
  return vocab_size, hidden_size, hidden_size * 2, layers, heads, safe_block_size, False, kv_heads
590
+
591
  @spaces.GPU
592
  def _get_eval_dataset(train_ds_id, eval_ds_id, uploaded_val_data, update_logs_fn):
593
  if eval_ds_id:
 
609
  return None
610
  yield update_logs_fn("No se proporcionó dataset de evaluación. Omitiendo.", "Evaluación")
611
  return None
612
+
613
  def _create_training_args(output_dir, repo_id, **kwargs):
614
  neftune_alpha = float(kwargs.get('neftune_noise_alpha', 0.0))
615
  optim_args_dict = {}
 
659
  else:
660
  raise ValueError("Para datasets en streaming se requiere un valor positivo para 'Máximos Pasos de Entrenamiento'.")
661
  return TrainingArguments(**args_dict)
662
+
663
  @spaces.GPU
664
  def _generic_model_loader(model_name_or_path, model_class, **kwargs):
665
  config_kwargs = {"trust_remote_code": True}
 
677
  model_kwargs.update({"num_labels": kwargs['num_labels'], "ignore_mismatched_sizes": True})
678
  model = model_class.from_pretrained(model_name_or_path, **model_kwargs)
679
  return model
680
+
681
  @spaces.GPU
682
  def _find_all_linear_names(model):
683
  cls = torch.nn.Linear
 
690
  lora_module_names.remove('lm_head')
691
  common_targets = {'q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'}
692
  return list(lora_module_names.intersection(common_targets)) or list(lora_module_names)
693
+
694
  @spaces.GPU
695
  def _sft_formatting_func(example, text_col, tokenizer, **kwargs):
696
  if kwargs.get('sft_format_style') == "Conversacional":
 
724
  return "\n".join([m['content'] for m in messages])
725
  return ""
726
  return example.get(text_col, "")
727
+
728
  @spaces.GPU
729
  def _dpo_formatting_func(example, **kwargs):
730
  return {"prompt": example.get(kwargs.get('prompt_col_input', 'prompt'), ""), "chosen": example.get(kwargs.get('dpo_chosen_col_input', 'chosen'), ""), "rejected": example.get(kwargs.get('dpo_rejected_col_input', 'rejected'), "")}
731
+
732
  @spaces.GPU
733
  def _evaluate_perplexity(model, tokenizer, eval_dataset, text_col):
734
  model.eval()
 
753
  break
754
  ppl = torch.exp(torch.stack(nlls).mean())
755
  return ppl.item()
756
+
757
  @spaces.GPU
758
  def _merge_multiple_loras(base_model_id, adapter_ids_str, weights_str, combination_type):
759
  adapter_ids = [s.strip() for s in adapter_ids_str.split(',') if s.strip()]
 
785
  tokenizer.save_pretrained(temp_dir)
786
  yield f"Fusión de adaptadores completada. El entrenamiento continuará con el modelo fusionado en {temp_dir}."
787
  return temp_dir
788
+
789
  @spaces.GPU
790
  def _run_trainer_and_upload(trainer, tokenizer, repo_id, update_logs_fn, model_card_content, **kwargs):
791
  yield update_logs_fn("Iniciando ciclo de entrenamiento...", "Entrenando")
 
806
  yield update_logs_fn("Subiendo al Hub...", "Subiendo")
807
  upload_folder(folder_path=output_dir, repo_id=repo_id, commit_message="Fin de entrenamiento")
808
  return output_dir, final_metrics
809
+
810
  @spaces.GPU
811
  def train_sft_dpo(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
812
  output_dir = tempfile.mkdtemp()
 
860
  return final_model_path, final_metrics
861
  except Exception as e:
862
  raise Exception(f"Error en {'DPO' if is_dpo else 'SFT'}: {e}\n{traceback.format_exc()}")
863
+
864
  @spaces.GPU
865
  def train_sequence_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
866
  output_dir = tempfile.mkdtemp()
 
903
  return final_model_path, final_metrics
904
  except Exception as e:
905
  raise Exception(f"Error en Sequence Classification: {e}\n{traceback.format_exc()}")
906
+
907
  @spaces.GPU
908
  def train_token_classification(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
909
  output_dir = tempfile.mkdtemp()
 
961
  return final_model_path, final_metrics
962
  except Exception as e:
963
  raise Exception(f"Error en Token Classification: {e}\n{traceback.format_exc()}")
964
+
965
  @spaces.GPU
966
  def train_question_answering(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
967
  output_dir = tempfile.mkdtemp()
 
1039
  return final_model_path, final_metrics
1040
  except Exception as e:
1041
  raise Exception(f"Error en Question Answering: {e}\n{traceback.format_exc()}")
1042
+
1043
  @spaces.GPU
1044
  def train_seq2seq(model_name, train_dataset, repo_id, update_logs_fn, model_card_content, **kwargs):
1045
  output_dir = tempfile.mkdtemp()
 
1091
  return final_model_path, final_metrics
1092
  except Exception as e:
1093
  raise Exception(f"Error en Seq2Seq: {e}\n{traceback.format_exc()}")
1094
+
1095
  @spaces.GPU
1096
  def train_text_to_image(model_name, train_dataset, repo_id, update_logs, model_card_content, **kwargs):
1097
  output_dir = tempfile.mkdtemp()
 
1243
  except Exception as e:
1244
  yield update_logs(f"❌ Error en entrenamiento Text-to-Image: {str(e)}", "Error")
1245
  raise Exception(f"Error en Text-to-Image: {e}\n{traceback.format_exc()}")
1246
+
1247
  @spaces.GPU
1248
  def _get_data_processing_pipeline(**kwargs):
1249
  hf_ids = [x.strip() for x in (kwargs.get('datasets_hf_text') or "").split(",") if x.strip()]
 
1306
  num_perm=int(kwargs.get('minhash_num_perm', 128))
1307
  )
1308
  return train_dataset, kwargs
1309
+
1310
  @spaces.GPU
1311
  def _train_and_upload(progress=gr.Progress(), **kwargs):
1312
  logs, repo_link, final_model_path, final_metrics = "", "", None, {}
 
1475
  gr.update(value="Iniciar Entrenamiento", interactive=True),
1476
  gr.update(visible=False)
1477
  )
1478
+
1479
  @spaces.GPU
1480
  def run_inference(task_mode, model_id, text_in, context_in, image_in, audio_in, temperature, top_p, max_new_tokens):
1481
  if not model_id: return "Por favor, introduce un ID de modelo del Hub.", model_id, gr.update(), gr.update(), gr.update(), gr.update()
 
1499
  result = pipe(input_data)
1500
  return f"Resultado:\n\n{json.dumps(result, indent=2, ensure_ascii=False)}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1501
  except Exception as e: return f"Error en Inferencia: {e}\n{traceback.format_exc()}", model_id, gr.update(), gr.update(), gr.update(), gr.update()
1502
+
1503
  def update_inference_ui(task_mode):
1504
  task_name = TASK_TO_PIPELINE_MAP.get(task_mode, "")
1505
  is_text_gen = task_name == "text-generation"
 
1515
  gr.update(visible=show_audio),
1516
  gr.update(visible=is_text_gen)
1517
  )
1518
+
1519
  @spaces.GPU
1520
  def create_and_upload_dataset(hf_token, repo_name, creation_type, synth_model, synth_prompt, synth_num_samples, file_uploads, progress=gr.Progress()):
1521
  if not hf_token:
 
1577
  return f"✅ Dataset creado y subido exitosamente a {repo_id}", f"### ✅ [Dataset Disponible: Visita el Repositorio]({dataset_link})"
1578
  except Exception as e:
1579
  return f"❌ Error fatal durante la creación del dataset: {e}\n{traceback.format_exc()}", ""
1580
+
1581
@spaces.GPU
def gradio_train_wrapper(*args):
    """Gradio entry point: pair positional UI values with their component names and delegate.

    The UI wires every registered input component positionally; this shim
    rebuilds the keyword mapping expected by ``_train_and_upload`` and
    streams its progress updates through.
    """
    named_values = {name: value for name, value in zip(all_input_components_dict, args)}
    yield from _train_and_upload(**named_values)
1585
+
1586
  @spaces.GPU
1587
  def gradio_preview_data_wrapper(*args):
1588
  kwargs = dict(zip(all_input_components_dict.keys(), args))
 
1624
  yield preview_text
1625
  except Exception as e:
1626
  yield f"Error al generar la vista previa: {e}\n{traceback.format_exc()}"
1627
+
1628
  def toggle_training_mode_ui(is_scratch):
1629
  return (
1630
  gr.update(visible=not is_scratch),
 
1646
  gr.update(visible=is_scratch),
1647
  gr.update(visible=is_scratch),
1648
  )
1649
+
1650
  def toggle_task_specific_ui(training_mode):
1651
  is_classification = "Classification" in training_mode
1652
  is_dpo = "DPO" in training_mode
 
1660
  gr.update(visible=is_diffusion),
1661
  gr.update(visible=not is_diffusion)
1662
  )
1663
+
1664
def toggle_sft_format_ui(format_style):
    """Show the tool/reasoning options only for the 'Razonamiento/Herramientas' SFT format."""
    return gr.update(visible=(format_style == "Razonamiento/Herramientas"))
1667
+
1668
def toggle_auto_modules_ui(is_auto):
    """Hide the manual target-modules field whenever auto-detection is enabled."""
    show_manual_field = not is_auto
    return gr.update(visible=show_manual_field)
1670
+
1671
def toggle_dataset_creator_ui(choice):
    """Swap between the synthetic-generation panel and the file-upload panel."""
    synthetic_selected = (choice == "Sintético")
    return (
        gr.update(visible=synthetic_selected),
        gr.update(visible=not synthetic_selected),
    )
1674
 
1675
def get_ao_username(token):
    """Resolve the Hub username for *token*, or "anonymous" when lookup fails.

    Any failure (missing/invalid token, network error, missing client) is
    collapsed into the "anonymous" sentinel that callers test against.
    """
    try:
        return HfApi(token=token).whoami()["name"]
    except Exception:
        return "anonymous"
1682
+
1683
def check_ao_model_exists(username, quantization_type, group_size, model_name, quantized_model_name, token):
    """Return a message when the target quantized repo already exists, else None.

    On any Hub/API failure an error string is returned instead of raising;
    callers surface it as a warning.  The repo-name construction mirrors
    ``save_ao_model`` so the existence check matches what would be pushed.
    """
    try:
        # BUGFIX: the original called bare `list_models`, but only `HfApi` is
        # imported from huggingface_hub at the top of the file, which raised
        # NameError at runtime.  Use the client method instead.
        api = HfApi(token=token)
        existing_ids = {m.id for m in api.list_models(author=username)}
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            base = model_name.split("/")[-1]
            suffix = MAP_QUANT_TYPE_TO_NAME[quantization_type]
            # Group size only appears in the name for the grouped schemes.
            if quantization_type in ("Int4WeightOnly", "GemliteUIntXWeightOnly") and group_size is not None:
                repo_name = f"{username}/{base}-ao-{suffix}-gs{group_size}"
            else:
                repo_name = f"{username}/{base}-ao-{suffix}"
        if repo_name in existing_ids:
            return f"Model '{repo_name}' already exists in your repository."
        return None
    except Exception as e:
        return f"Error checking model existence: {str(e)}"
1700
+
1701
def create_ao_model_card(model_name, quantization_type, group_size, token):
    """Build the README for a quantized repo.

    Prepends a YAML header and quantization metadata; when the source model's
    README can be fetched it is appended under an "Original Model Info"
    section.  Download failures are non-fatal.
    """
    try:
        # BUGFIX: the original used `snapshot_download`, which is not imported
        # at the top of the file (NameError at runtime).  `hf_hub_download` IS
        # imported, and fetching the single README file is cheaper anyway.
        readme_path = hf_hub_download(
            repo_id=model_name, filename="README.md", repo_type="model", token=token
        )
        with open(readme_path, "r", encoding="utf-8") as f:
            original_readme = f.read()
    except Exception:
        # Missing README or any download error: proceed with metadata only.
        original_readme = ""

    yaml_header = f"""---
base_model:
- {model_name}
tags:
- torchao-my-repo
---
# {model_name} (Quantized)

## Quantization Details
- **Quantization Type**: {quantization_type}
- **Group Size**: {group_size}

"""
    if original_readme:
        yaml_header += "\n\n# 📄 Original Model Info\n\n" + original_readme
    return yaml_header
1728
+
1729
def quantize_ao_model(model_name, quantization_type, group_size=128, token=None, progress=gr.Progress()):
    """Load *model_name* with a torchao quantization config applied and return the model.

    NOTE(review): this loads via ``AutoModel`` (no task head) — for causal LMs
    the lm_head may not be included in the quantized checkpoint; confirm this
    matches the intended use before relying on generation from the result.
    """
    print(f"Quantizing model: {quantization_type}")
    progress(0, desc="Preparing Quantization")

    # Resolve the torchao config for the chosen scheme.  Only the grouped
    # schemes consume group_size, and "autoquant" is a plain string sentinel
    # understood by TorchAoConfig.
    if quantization_type == "autoquant":
        scheme_config = "autoquant"
    elif quantization_type == "GemliteUIntXWeightOnly":
        scheme_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](group_size=group_size)
    elif quantization_type == "Int4WeightOnly":
        # Int4 weight-only needs an explicit CPU layout since the model is
        # loaded with device_map="cpu" below.
        from torchao.dtypes import Int4CPULayout
        scheme_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type](
            group_size=group_size, layout=Int4CPULayout()
        )
    else:
        scheme_config = MAP_QUANT_TYPE_TO_CONFIG[quantization_type]()

    hf_quant_config = TorchAoConfig(scheme_config)
    progress(0.10, desc="Quantizing model")

    # Quantization happens during the from_pretrained load itself.
    quantized = AutoModel.from_pretrained(
        model_name,
        torch_dtype="auto",
        quantization_config=hf_quant_config,
        device_map="cpu",
        token=token,
    )
    progress(0.45, desc="Quantization completed")
    return quantized
1755
+
1756
def save_ao_model(model, model_name, quantization_type, group_size=128, quantized_model_name=None, public=True, token=None, progress=gr.Progress()):
    """Serialize the quantized model + tokenizer, add a model card, and push to the Hub.

    Returns an HTML snippet linking to the created repository.
    """
    username = get_ao_username(token)
    progress(0.50, desc="Preparing to push")
    print("Saving quantized model")

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Re-download the tokenizer from the source repo so the quantized
        # repo is self-contained.
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        tokenizer.save_pretrained(tmpdirname)
        # safe_serialization=False — presumably because torchao-quantized
        # weights don't round-trip through safetensors; TODO confirm.
        model.save_pretrained(tmpdirname, safe_serialization=False)

        # Repo-name construction mirrors check_ao_model_exists: explicit name
        # wins, otherwise derive from the base model + scheme (+ group size
        # for the grouped schemes).
        if quantized_model_name:
            repo_name = f"{username}/{quantized_model_name}"
        else:
            if quantization_type in ["Int4WeightOnly", "GemliteUIntXWeightOnly"] and (group_size is not None):
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}-gs{group_size}"
            else:
                repo_name = f"{username}/{model_name.split('/')[-1]}-ao-{MAP_QUANT_TYPE_TO_NAME[quantization_type]}"

        progress(0.70, desc="Creating model card")
        model_card = create_ao_model_card(model_name, quantization_type, group_size, token)
        with open(os.path.join(tmpdirname, "README.md"), "w") as f:
            f.write(model_card)

        # exist_ok=True: pushing over an existing repo updates it in place.
        api = HfApi(token=token)
        api.create_repo(repo_name, exist_ok=True, private=not public)
        progress(0.80, desc="Pushing to Hub")
        api.upload_folder(folder_path=tmpdirname, repo_id=repo_name, repo_type="model")
        progress(1.00, desc="Done")

    # HTML fragment rendered by the Gradio Markdown output.
    repo_link = f"""
<div class="repo-link">
    <h3>🔗 Repository Link</h3>
    <p>Find your repo here: <a href="https://huggingface.co/{repo_name}" target="_blank">{repo_name}</a></p>
</div>
"""
    return f"<h1>🎉 Quantization Completed</h1><br/>{repo_link}"
1792
+
1793
@spaces.GPU
def quantize_and_save_ao(model_name, quantization_type, group_size, quantized_model_name, public, hf_token):
    """End-to-end TorchAO flow: authenticate, check for collisions, quantize, push.

    Returns an HTML status fragment for the UI in every branch (never raises).
    """
    # Authentication gate: get_ao_username collapses every failure into the
    # "anonymous" sentinel.
    username = get_ao_username(hf_token)
    if not username or username == "anonymous":
        return "<div class='error-box'><h3>❌ Authentication Error</h3><p>Invalid or missing HF_TOKEN.</p></div>"

    # The UI delivers group_size as free text; normalize to int or None.
    raw_group_size = str(group_size).strip() if group_size else ""
    if raw_group_size:
        try:
            group_size = int(raw_group_size)
        except ValueError:
            group_size = None
    else:
        group_size = None

    # Refuse to clobber an existing repo with the same derived name.
    exists_message = check_ao_model_exists(username, quantization_type, group_size, model_name, quantized_model_name, hf_token)
    if exists_message:
        return f"<div class='warning-box'><h3>⚠️ Model Already Exists</h3><p>{exists_message}</p></div>"

    try:
        quantized_model = quantize_ao_model(model_name, quantization_type, group_size, token=hf_token)
        return save_ao_model(quantized_model, model_name, quantization_type, group_size, quantized_model_name, public, token=hf_token)
    except Exception as e:
        return f"<div class='error-box'><h3>❌ Error</h3><p>{str(e)}</p></div>"
1816
+
1817
def get_awq_default_repo_name(model_id: str, scheme: str) -> str:
    """Suggest a default Hub repo name for an AWQ-compressed model.

    Returns "" when either input is empty; the "<your-username>" placeholder
    is meant to be replaced by the user's actual account name.
    """
    if model_id and scheme:
        base_name = Path(model_id).name
        return f"<your-username>/{base_name}-AWQ-{scheme}"
    return ""
1823
+
1824
@spaces.GPU
def run_awq_compression(
    hf_token: str,
    model_id: str,
    scheme: str,
    ignore_lm_head: bool,
    num_calib_samples: float,
    max_seq_len: float,
    pipeline_mode: str,
    upload_repo: str,
    progress=gr.Progress(track_tqdm=True),
):
    """Run AWQ quantization via llmcompressor's `oneshot` and optionally push to the Hub.

    This is a generator: each `yield` streams the full accumulated log so the
    UI textbox always shows the complete history.  Every failure path cleans
    up the temp directory and returns early instead of raising.
    """
    logs = []

    def log(msg: str) -> str:
        # Append and return the whole transcript — yielded as one block.
        logs.append(msg)
        return "\n".join(logs)

    if not model_id:
        yield log("Error: Please provide a source model id (e.g. meta-llama/Llama-3.3-70B-Instruct).")
        return

    # Gradio Numbers arrive as floats; the compressor wants ints.
    try:
        num_calib_samples_int = int(num_calib_samples)
        max_seq_len_int = int(max_seq_len)
    except ValueError as e:
        yield log(f"Error: Invalid number format for calibration settings. {e}")
        return

    # Work in a throwaway directory; removed in every exit path below.
    temp_dir = tempfile.mkdtemp()
    local_output_dir = Path(temp_dir) / f"{Path(model_id).name}-AWQ-{scheme}"
    yield log(f"ℹ️ Quantized model will be saved temporarily to: {local_output_dir.name}")

    # Login is optional: public models can be quantized without a token.
    if hf_token:
        try:
            login(token=hf_token)
            yield log("✅ Logged in to Hugging Face Hub.")
        except Exception as e:
            yield log(f"⚠️ Hugging Face login failed: {e}")
    else:
        yield log("ℹ️ No HF token provided. You can still quantize public models and save locally.")

    try:
        progress(0.1, desc="Building AWQ recipe...")
        yield log("🔧 Building AWQ recipe...")

        # Optionally exclude the output projection from quantization.
        ignore_patterns = ["lm_head"] if ignore_lm_head else None
        recipe = AWQModifier(
            targets="Linear",
            scheme=scheme,
            ignore=ignore_patterns,
        )
        yield log(f"Recipe:\n  scheme = {scheme}\n  ignore = {ignore_patterns or '[]'}")

    except Exception as e:
        yield log(f"❌ Failed to build AWQ recipe: {e}")
        shutil.rmtree(temp_dir, ignore_errors=True)
        return

    try:
        progress(0.25, desc="Running AWQ quantization...")
        yield log("🚀 Starting LLM Compressor `oneshot` run (no calibration dataset)...")
        yield log(f"  • model = {model_id}")
        yield log(f"  • num_calibration_samples = {num_calib_samples_int}")
        yield log(f"  • max_seq_length = {max_seq_len_int}")
        yield log(f"  • pipeline = {pipeline_mode}")

        # NOTE(review): dataset=None with a nonzero num_calibration_samples,
        # and device="cpu" under @spaces.GPU — confirm these are intended.
        oneshot(
            model=model_id,
            dataset=None,
            recipe=recipe,
            output_dir=str(local_output_dir),
            max_seq_length=max_seq_len_int,
            num_calibration_samples=num_calib_samples_int,
            pipeline=pipeline_mode,
            trust_remote_code_model=True,
            device="cpu",
        )

        progress(0.8, desc="Quantization complete. Preparing upload...")
        yield log("✅ AWQ quantization finished.")

    except Exception as e:
        progress(1.0, desc="Error")
        yield log(f"❌ CRITICAL ERROR during oneshot:\n{traceback.format_exc()}")
        shutil.rmtree(temp_dir, ignore_errors=True)
        return

    # Upload is best-effort: failures are logged but don't abort the run.
    if upload_repo and hf_token:
        try:
            progress(0.9, desc="Uploading compressed model to Hugging Face Hub...")
            yield log(f"☁️ Uploading folder `{local_output_dir.name}` to repo `{upload_repo}`...")

            api = HfApi(token=hf_token)
            api.create_repo(repo_id=upload_repo, repo_type="model", exist_ok=True)
            api.upload_folder(
                folder_path=str(local_output_dir),
                repo_id=upload_repo,
                repo_type="model",
            )

            hub_url = f"https://huggingface.co/{upload_repo}"
            yield log(f"✅ Upload complete. Model available at:\n{hub_url}")

        except Exception as e:
            yield log(f"⚠️ Upload failed: {e}")
    else:
        yield log("ℹ️ No upload repo configured. Local files saved to temporary location.")

    # NOTE(review): the temp dir is always removed here, so in the no-upload
    # branch the "saved to temporary location" files are gone by the time the
    # final message prints — confirm this is the intended behavior.
    shutil.rmtree(temp_dir, ignore_errors=True)
    progress(1.0, desc="Done!")
    yield log("🎉 Done! AWQ compression finished successfully. Local temporary files cleaned up.")
1936
+
1937
+
1938
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1939
+ gr.Markdown("# 🚀 AutoTrain-Advanced & Quantization Hub")
1940
+ gr.Markdown("### Una plataforma unificada para Fine-Tuning, PEFT, TorchAO y AWQ Quantization.")
1941
 
1942
  with gr.Tab("1. Autenticación"):
1943
  gr.Markdown("#### Conecta tu cuenta de Hugging Face para guardar y cargar modelos.")
 
2253
  inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_temperature, inf_top_p, inf_max_new_tokens],
2254
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
2255
  )
2256
+
2257
+ with gr.Tab("5. TorchAO Quantization"):
2258
+ gr.Markdown("## 🔥 TorchAO Quantizer")
2259
+ gr.Markdown("Cuantización eficiente usando `torchao`.")
2260
+ with gr.Row():
2261
+ ao_token = gr.Textbox(label="HF Token (si es diferente al principal)", type="password", placeholder="Opcional")
2262
+ ao_model_name = HuggingfaceHubSearch(label="🔍 Hub Model ID", placeholder="Search a model", search_type="model")
2263
+ ao_quant_type = gr.Dropdown(choices=list(MAP_QUANT_TYPE_TO_NAME.keys()), value="Int8WeightOnly", label="Tipo de Cuantización")
2264
+ ao_group_size = gr.Textbox(label="Group Size (opcional)", value="128")
2265
+ ao_custom_name = gr.Textbox(label="Nombre Personalizado (opcional)", value="")
2266
+ ao_public = gr.Checkbox(label="Hacer Público", value=True)
2267
+ ao_output = gr.Markdown()
2268
+ ao_btn = gr.Button("🚀 Cuantizar y Subir", variant="primary")
2269
+
2270
+ ao_btn.click(
2271
+ quantize_and_save_ao,
2272
+ inputs=[ao_model_name, ao_quant_type, ao_group_size, ao_custom_name, ao_public, hf_token_input],
2273
+ outputs=ao_output
2274
+ )
2275
+
2276
+ with gr.Tab("6. AWQ Quantization"):
2277
+ gr.Markdown("## 🧱 LLM Compressor – AWQ Quantizer")
2278
+ gr.Markdown("Cuantización AWQ usando `llmcompressor` (oneshot).")
2279
+ with gr.Row():
2280
+ with gr.Column():
2281
+ awq_token = gr.Textbox(label="HF Token (si es diferente al principal)", type="password", placeholder="Opcional")
2282
+ awq_model_id = gr.Textbox(label="Source Model ID", value="meta-llama/Llama-3.3-70B-Instruct")
2283
+ awq_scheme = gr.Dropdown(label="AWQ Scheme", choices=["W4A16", "W4A16_ASYM"], value="W4A16_ASYM")
2284
+ awq_ignore_head = gr.Checkbox(label="Ignore lm_head", value=True)
2285
+ awq_calib = gr.Number(label="Calibration Samples", value=128, precision=0)
2286
+ awq_seq_len = gr.Number(label="Max Sequence Length", value=2048, precision=0)
2287
+ awq_pipeline = gr.Dropdown(label="Pipeline Mode", choices=["sequential", "default"], value="sequential")
2288
+ awq_repo = gr.Textbox(label="Target HF Repo", placeholder="username/model-awq")
2289
+ awq_btn = gr.Button("Iniciar Compresión AWQ", variant="primary")
2290
+ with gr.Column():
2291
+ awq_logs = gr.Textbox(label="Logs del Proceso", lines=30, interactive=False)
2292
+
2293
+ awq_btn.click(
2294
+ run_awq_compression,
2295
+ inputs=[hf_token_input, awq_model_id, awq_scheme, awq_ignore_head, awq_calib, awq_seq_len, awq_pipeline, awq_repo],
2296
+ outputs=[awq_logs]
2297
+ )
2298
+
2299
+ with gr.Tab("7. Explicación del Código"):
2300
  gr.Markdown("""
2301
  ### 🧠 Explicación del Código y Mecanismos Avanzados
2302
  """)
 
2324
  * Task-Specific Heads: Supports **Sequence Classification**, **Token Classification (NER)**, and **Question Answering** by loading appropriate model heads (`AutoModelFor...`).
2325
  * Seq2Seq: For translation/summarization tasks, using `Seq2SeqTrainer`.
2326
  """)
2327
+ gr.Markdown("#### 4. QUANTIZATION (TorchAO & AWQ)")
2328
  gr.Markdown("""
2329
+ * **TorchAO**: PyTorch Native Quantization. Supports Int4, Int8, and Float8 quantization techniques directly integrated with the model loading process.
2330
+ * **AWQ (Activation-aware Weight Quantization)**: Uses `llmcompressor` in oneshot mode to protect salient weights based on activation magnitude, preserving performance at 4-bit.
 
 
 
 
2331
  """)
2332
  gr.Markdown("#### 5. OUTPUT & DEPLOYMENT")
2333
  gr.Markdown("""
 
2337
  """)
2338
 
2339
# Script entry point: enable a bounded request queue (long-running training
# and quantization jobs would otherwise block concurrent users) and expose a
# public share link alongside debug output.
if __name__ == "__main__":
    demo.queue(max_size=50).launch(debug=True, share=True)