Ignaciohhhhggfgjfrffd committed on
Commit e3b8521 · verified · 1 parent: ed44448

Update app.py

Files changed (1)
  1. app.py +59 -623
app.py CHANGED
@@ -49,7 +49,7 @@ from transformers import (
     PhiConfig, PhiForCausalLM, Qwen2Config, Qwen2ForCausalLM,
     DataCollatorForLanguageModeling, DefaultDataCollator, Adafactor
 )
-from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training
 from trl import SFTTrainer, DPOTrainer
 from diffusers import (
     UNet2DConditionModel, DDPMScheduler, AutoencoderKL, DiffusionPipeline,
@@ -157,36 +157,44 @@ class DebiasingSFTTrainer(SFTTrainer):
             break
         return (loss, outputs) if return_outputs else loss

-@spaces.GPU()
-def _create_deduplicated_iterable_dataset(dataset, text_col, method, threshold=0.85, num_perm=128):
-    lsh_state = MinHashLSH(threshold=threshold, num_perm=num_perm) if method == 'Semántica (MinHash)' else None
-    seen_texts_state = set() if method == 'Exacta' else None
-    def gen():
-        if method == 'Exacta':
-            for example in dataset:
-                text = example.get(text_col, "")
-                if text and isinstance(text, str):
-                    if text not in seen_texts_state:
-                        seen_texts_state.add(text)
-                        yield example
-                else:
                     yield example
-        elif method == 'Semántica (MinHash)':
-            for i, example in enumerate(dataset):
-                text = example.get(text_col, "")
-                if text and isinstance(text, str) and text.strip():
-                    m = MinHash(num_perm=num_perm)
-                    for d in text.split():
-                        m.update(d.encode('utf8'))
-                    if not lsh_state.query(m):
-                        lsh_state.insert(f"key_{i}", m)
-                        yield example
-                else:
                     yield example
-        else:
-            yield from dataset
-    new_ds = IterableDataset.from_generator(gen)
-    return new_ds

 @spaces.GPU()
 def hf_login(token):
@@ -1834,594 +1842,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
1834
  inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_temperature, inf_top_p, inf_max_new_tokens],
1835
  outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
1836
  )
1837
- if __name__ == "__main__":
1838
- demo.queue().launch(debug=True, share=True)
1839
- TRANSFORMERS_AVAILABLE = True
1840
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
1841
- def is_hub_repo_like_s1(s):
1842
- return "/" in s and not Path(s).exists()
1843
- def download_from_hf_s1(repo_id, filename, token=None):
1844
- token = token or os.environ.get("HF_TOKEN")
1845
- return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
1846
- class MediaTextDataset_s1(Dataset):
1847
- def __init__(self, source, csv_name="dataset.csv", text_columns=None, max_records=None):
1848
- self.is_hub = is_hub_repo_like_s1(source)
1849
- token = os.environ.get("HF_TOKEN")
1850
- if self.is_hub:
1851
- file_path = download_from_hf_s1(source, csv_name, token)
1852
- else:
1853
- file_path = Path(source) / csv_name
1854
- if not Path(file_path).exists():
1855
- alt = Path(str(file_path).replace(".csv", ".parquet"))
1856
- if alt.exists():
1857
- file_path = alt
1858
- else:
1859
- raise FileNotFoundError(f"Dataset file not found: {file_path}")
1860
- self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
1861
- if max_records:
1862
- self.df = self.df.head(max_records)
1863
- self.text_columns = text_columns or ["short_prompt", "long_prompt"]
1864
- def __len__(self):
1865
- return len(self.df)
1866
- def __getitem__(self, i):
1867
- rec = self.df.iloc[i]
1868
- out = {"text": {}}
1869
- for col in self.text_columns:
1870
- out["text"][col] = rec[col] if col in rec else ""
1871
- return out
1872
- def load_pipeline_auto_s1(base_model, dtype=torch.float16):
1873
- if "gemma" in base_model.lower():
1874
- if not TRANSFORMERS_AVAILABLE:
1875
- raise RuntimeError("Transformers not installed for LLM support.")
1876
- tokenizer = AutoTokenizer.from_pretrained(base_model)
1877
- model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
1878
- return {"model": model, "tokenizer": tokenizer}
1879
- else:
1880
- raise NotImplementedError("Only Gemma LLM supported in this script.")
-def find_target_modules_s1(model):
-    candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
-    names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
-    targets = [n.split(".")[-1] for n in names if any(c in n for c in candidates)]
-    if not targets:
-        targets = [n.split(".")[-1] for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
-    return targets
-def unwrap_batch_s1(batch, short_col, long_col):
-    if isinstance(batch, (list, tuple)):
-        ex = batch[0]
-        if "text" in ex:
-            return ex
-        if "short" in ex and "long" in ex:
-            return {"text": {short_col: ex.get("short",""), long_col: ex.get("long","")}}
-        return {"text": ex}
-    if isinstance(batch, dict):
-        first_elem = {}
-        is_batched = any(isinstance(v, (list, tuple, np.ndarray, torch.Tensor)) for v in batch.values())
-        if is_batched:
-            for k, v in batch.items():
-                try: first = v[0]
-                except Exception: first = v
-                first_elem[k] = first
-            if "text" in first_elem:
-                t = first_elem["text"]
-                if isinstance(t, (list, tuple)) and len(t) > 0:
-                    return {"text": t[0] if isinstance(t[0], dict) else {short_col: t[0], long_col: ""}}
-                if isinstance(t, dict): return {"text": t}
-                return {"text": {short_col: str(t), long_col: ""}}
-            if ("short" in first_elem and "long" in first_elem) or (short_col in first_elem and long_col in first_elem):
-                s = first_elem.get(short_col, first_elem.get("short", ""))
-                l = first_elem.get(long_col, first_elem.get("long", ""))
-                return {"text": {short_col: str(s), long_col: str(l)}}
-            return {"text": {short_col: str(first_elem)}}
-        if "text" in batch and isinstance(batch["text"], dict):
-            return {"text": batch["text"]}
-        s = batch.get(short_col, batch.get("short", ""))
-        l = batch.get(long_col, batch.get("long", ""))
-        return {"text": {short_col: str(s), long_col: str(l)}}
-    return {"text": {short_col: str(batch), long_col: ""}}
-def train_lora_stream_s1(base_model, dataset_src, csv_name, text_cols, output_dir,
-                         epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1, num_workers=0,
-                         max_train_records=None):
-    accelerator = accelerate.Accelerator()
-    pipe = load_pipeline_auto_s1(base_model)
-    model_obj = pipe["model"]
-    tokenizer = pipe["tokenizer"]
-    model_obj.train()
-    target_modules = find_target_modules_s1(model_obj)
-    lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
-    lora_module = get_peft_model(model_obj, lcfg)
-    dataset = MediaTextDataset_s1(dataset_src, csv_name, text_columns=text_cols, max_records=max_train_records)
-    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-    optimizer = torch.optim.AdamW(lora_module.parameters(), lr=lr)
-    lora_module, optimizer, loader = accelerator.prepare(lora_module, optimizer, loader)
-    total_steps = max(1, epochs * len(loader))
-    step_counter = 0
-    logs = []
-    yield "[DEBUG] Starting training loop...\n", 0.0
-    for ep in range(epochs):
-        yield f"[DEBUG] Epoch {ep+1}/{epochs}\n", step_counter / total_steps
-        for i, batch in enumerate(loader):
-            ex = unwrap_batch_s1(batch, text_cols[0], text_cols[1])
-            texts = ex.get("text", {})
-            short_text = str(texts.get(text_cols[0], "") or "")
-            long_text = str(texts.get(text_cols[1], "") or "")
-            enc = tokenizer(short_text, text_pair=long_text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
-            enc = {k: v.to(accelerator.device) for k, v in enc.items()}
-            enc["labels"] = enc["input_ids"].clone()
-            outputs = lora_module(**enc)
-            forward_loss = getattr(outputs, "loss", None)
-            if forward_loss is None:
-                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
-                forward_loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), enc["labels"].view(-1), ignore_index=tokenizer.pad_token_id)
-            logs.append(f"[DEBUG] Step {step_counter}, forward_loss: {forward_loss.item():.6f}")
-            optimizer.zero_grad()
-            accelerator.backward(forward_loss)
-            optimizer.step()
-            step_counter += 1
-            yield "\n".join(logs[-10:]), step_counter / total_steps
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
-    lora_module.save_pretrained(output_dir)
-    yield f"[INFO] ✅ LoRA saved to {output_dir}\n", 1.0
-def upload_adapter_s1(local, repo_id):
-    token = os.environ.get("HF_TOKEN")
-    if not token:
-        raise RuntimeError("HF_TOKEN missing")
-    create_repo(repo_id, exist_ok=True)
-    upload_folder(local, repo_id=repo_id, repo_type="model", token=token)
-    return f"https://huggingface.co/{repo_id}"
-def run_ui_s1():
-    with gr.Blocks() as demo_s1:
-        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer (Gemma LLM)")
-        with gr.Row():
-            base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
-            csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
-            short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
-            long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
-            out = gr.Textbox(label="Output dir", value="./adapter_out")
-            repo = gr.Textbox(label="Upload HF repo (optional)", value="rahul7star/gemma-3-270m-ccebc0")
-        with gr.Row():
-            batch_size = gr.Number(value=1, label="Batch size")
-            num_workers = gr.Number(value=0, label="DataLoader num_workers")
-            r = gr.Number(value=8, label="LoRA rank")
-            a = gr.Number(value=16, label="LoRA alpha")
-            ep = gr.Number(value=1, label="Epochs")
-            lr = gr.Number(value=1e-4, label="Learning rate")
-            max_records = gr.Number(value=1000, label="Max training records")
-        logs = gr.Textbox(label="Logs (streaming)", lines=25)
-        def launch(bm, ds, csv, sc, lc, out_dir, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
-            gen = train_lora_stream_s1(bm, ds, csv, [sc, lc], out_dir, epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_), batch_size=int(batch), num_workers=int(num_w), max_train_records=int(max_rec))
-            for item in gen:
-                if isinstance(item, tuple):
-                    text = item[0]
-                else:
-                    text = item
-                yield text
-            if repo_:
-                link = upload_adapter_s1(out_dir, repo_)
-                yield f"[INFO] Uploaded to {link}\n"
-        btn = gr.Button("🚀 Start Training")
-        btn.click(fn=launch, inputs=[base_model, dataset, csvname, short_col, long_col, out, batch_size, num_workers, r, a, ep, lr, max_records, repo], outputs=[logs], queue=True)
-    return demo_s1
-CHRONOEDIT_AVAILABLE = False
-try:
-    from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
-    CHRONOEDIT_AVAILABLE = True
-except Exception:
-    pass
-QWENEDIT_AVAILABLE = False
-try:
-    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline
-    QWENEDIT_AVAILABLE = True
-except Exception:
-    pass
-BNB_AVAILABLE = False
-try:
-    from transformers import BitsAndBytesConfig
-    BNB_AVAILABLE = True
-except Exception:
-    BitsAndBytesConfig = None
-XFORMERS_AVAILABLE = False
-try:
-    import xformers
-    XFORMERS_AVAILABLE = True
-except Exception:
-    pass
-ADALORA_AVAILABLE = False
-try:
-    from peft import AdaLoraConfig
-    ADALORA_AVAILABLE = True
-except Exception:
-    AdaLoraConfig = None
-IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
-VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}
-def is_hub_repo_like_s2(s: str) -> bool:
-    return "/" in s and not Path(s).exists()
-def download_from_hf_s2(repo_id: str, filename: str, token: str = None, repo_type: str = "dataset") -> str:
-    token = token or os.environ.get("HF_TOKEN")
-    return hf_hub_download(repo_id=repo_id, filename=filename, use_auth_token=token, repo_type=repo_type)
-def try_list_repo_files_s2(repo_id: str, repo_type: str = "dataset", token: str = None):
-    token = token or os.environ.get("HF_TOKEN")
-    try:
-        return list_repo_files(repo_id, token=token, repo_type=repo_type)
-    except Exception:
-        return []
-def find_target_modules_s2(model, candidates=("q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k", "to_v", "proj_out", "to_out")):
-    names = [n for n, _ in model.named_modules()]
-    selected = set()
-    for cand in candidates:
-        for n in names:
-            if cand in n:
-                selected.add(n.split(".")[-1])
-    if not selected:
-        return ["to_q", "to_k", "to_v", "to_out"]
-    return list(selected)
-class MediaTextDataset_s2(Dataset):
-    def __init__(self, dataset_source: str, csv_name: str = "dataset.csv", max_frames: int = 5, image_size=(512,512), video_frame_size=(128,256), hub_token: str = None):
-        self.source = dataset_source
-        self.is_hub = is_hub_repo_like_s2(dataset_source)
-        self.max_frames = max_frames
-        self.image_size = image_size
-        self.video_frame_size = video_frame_size
-        self.hub_token = hub_token or os.environ.get("HF_TOKEN")
-        if self.is_hub:
-            try:
-                csv_local = download_from_hf_s2(self.source, csv_name, token=self.hub_token, repo_type="dataset")
-            except Exception:
-                alt = csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else csv_name + ".parquet"
-                csv_local = download_from_hf_s2(self.source, alt, token=self.hub_token, repo_type="dataset")
-            if str(csv_local).endswith(".parquet"):
-                df = pd.read_parquet(csv_local)
-            else:
-                df = pd.read_csv(csv_local)
-            self.df = df
-            self.root = None
-        else:
-            root = Path(dataset_source)
-            csv_path = root / csv_name
-            parquet_path = root / csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else root / (csv_name + ".parquet")
-            if csv_path.exists():
-                self.df = pd.read_csv(csv_path)
-            elif parquet_path.exists():
-                self.df = pd.read_parquet(parquet_path)
-            else:
-                p = root / csv_name
-                if p.exists():
-                    if p.suffix.lower() == ".parquet":
-                        self.df = pd.read_parquet(p)
-                    else:
-                        self.df = pd.read_csv(p)
-                else:
-                    raise FileNotFoundError(f"Can't find {csv_name} in {dataset_source}")
-            self.root = root
-        self.image_transform = T.Compose([T.ToPILImage(), T.Resize(image_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
-        self.video_transform = T.Compose([T.ToPILImage(), T.Resize(video_frame_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
-    def __len__(self):
-        return len(self.df)
-    def _maybe_download_from_hub(self, file_name: str) -> str:
-        if self.root is not None:
-            p = self.root / file_name
-            if p.exists():
-                return str(p)
-        return download_from_hf_s2(self.source, file_name, token=self.hub_token, repo_type="dataset")
-    def _read_video_frames(self, path: str, num_frames: int):
-        video_frames, _, _ = torchvision.io.read_video(str(path), pts_unit='sec')
-        total = len(video_frames)
-        if total == 0:
-            C, H, W = 3, self.video_frame_size[0], self.video_frame_size[1]
-            return torch.zeros((num_frames, C, H, W), dtype=torch.float32)
-        if total < num_frames:
-            idxs = list(range(total)) + [total-1]*(num_frames-total)
-        else:
-            idxs = np.linspace(0, total-1, num_frames).round().astype(int).tolist()
-        frames = []
-        for i in idxs:
-            arr = video_frames[i].numpy() if hasattr(video_frames[i], "numpy") else np.array(video_frames[i])
-            frames.append(self.video_transform(arr))
-        frames = torch.stack(frames, dim=0)
-        return frames
-    def __getitem__(self, idx):
-        rec = self.df.iloc[idx]
-        file_name = rec["file_name"]
-        caption = rec["text"]
-        if self.is_hub:
-            local_path = self._maybe_download_from_hub(file_name)
-        else:
-            local_path = str(Path(self.root) / file_name)
-        p = Path(local_path)
-        suffix = p.suffix.lower()
-        if suffix in IMAGE_EXTS:
-            img = torchvision.io.read_image(local_path)
-            if isinstance(img, torch.Tensor):
-                img = img.permute(1,2,0).numpy()
-            return {'type': 'image', 'image': self.image_transform(img), 'caption': caption, 'file_name': file_name}
-        elif suffix in VIDEO_EXTS:
-            frames = self._read_video_frames(local_path, self.max_frames)
-            return {'type': 'video', 'frames': frames, 'caption': caption, 'file_name': file_name}
-        else:
-            raise RuntimeError(f"Unsupported media type: {local_path}")
-def load_pipeline_auto_s2(base_model_id: str, use_4bit: bool = False, bnb_config: object = None, torch_dtype=torch.float16):
-    low = base_model_id.lower()
-    is_chrono = "chrono" in low or "wan" in low or "video" in low
-    is_qwen = "qwen" in low or "qwenimage" in low
-    if is_chrono and CHRONOEDIT_AVAILABLE:
-        if use_4bit and bnb_config is not None:
-            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
-        else:
-            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
-    elif is_qwen and QWENEDIT_AVAILABLE:
-        pipe = QwenImageEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
-    else:
-        if use_4bit and BNB_AVAILABLE and bnb_config is not None:
-            pipe = DiffusionPipeline.from_pretrained(base_model_id, quantization_config=bnb_config, torch_dtype=torch.float16)
-        else:
-            pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
-    return pipe
-def infer_target_for_task_s2(task_type: str, model_name: str) -> str:
-    low = model_name.lower()
-    if task_type == "prompt-lora" or "qwen" in low or "qwenedit" in low:
-        return "text_encoder"
-    if task_type == "text-video" or "chrono" in low or "wan" in low:
-        return "transformer"
-    return "unet"
-def attach_lora_s2(pipe, adapter_target: str, r: int = 8, alpha: int = 16, dropout: float = 0.0, use_adalora: bool = False):
-    if adapter_target == "unet":
-        target_module = pipe.unet
-        attr = "unet"
-    elif adapter_target == "transformer":
-        target_module = pipe.transformer
-        attr = "transformer"
-    elif adapter_target == "text_encoder":
-        target_module = pipe.text_encoder
-        attr = "text_encoder"
-    else:
-        raise RuntimeError("Unknown adapter_target")
-    target_modules = find_target_modules_s2(target_module)
-    if use_adalora and ADALORA_AVAILABLE:
-        lora_config = AdaLoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, init_r=4, lora_dropout=dropout)
-    else:
-        lora_config = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=dropout, bias="none", task_type="SEQ_2_SEQ_LM")
-    peft_model = get_peft_model(target_module, lora_config)
-    setattr(pipe, attr, peft_model)
-    return pipe, attr
-def train_lora_accelerate_s2(base_model_id: str, dataset_source: str, csv_name: str, task_type: str, adapter_target_override: str, output_dir: str, epochs: int = 1, batch_size: int = 1, lr: float = 1e-4, max_train_steps: int = None, lora_r: int = 8, lora_alpha: int = 16, use_4bit: bool = False, enable_xformers: bool = False, use_adalora: bool = False, gradient_accumulation_steps: int = 1, mixed_precision: str = None, save_every_steps: int = 200, max_frames: int = 5):
-    accelerator = accelerate.Accelerator(mixed_precision=mixed_precision or ("fp16" if torch.cuda.is_available() else "no"))
-    device = accelerator.device
-    bnb_conf = None
-    if use_4bit and BNB_AVAILABLE:
-        bnb_conf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
-    pipe = load_pipeline_auto_s2(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)
-    if enable_xformers:
-        try:
-            if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
-                pipe.enable_xformers_memory_efficient_attention()
-            elif hasattr(pipe, "enable_attention_slicing"):
-                pipe.enable_attention_slicing()
-        except Exception as e:
-            print(f"Could not enable xformers: {e}")
-    adapter_target = adapter_target_override if adapter_target_override else infer_target_for_task_s2(task_type, base_model_id)
-    pipe, attr = attach_lora_s2(pipe, adapter_target, r=lora_r, alpha=lora_alpha, dropout=0.0, use_adalora=use_adalora)
-    peft_module = getattr(pipe, attr)
-    dataset = MediaTextDataset_s2(dataset_source, csv_name=csv_name, max_frames=max_frames)
-    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)
-    trainable_params = [p for n,p in peft_module.named_parameters() if p.requires_grad]
-    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
-    peft_module, optimizer, dataloader = accelerator.prepare(peft_module, optimizer, dataloader)
-    logs = []
-    global_step = 0
-    loss_fn = nn.MSELoss()
-    timesteps = None
-    if hasattr(pipe, "scheduler"):
-        try:
-            pipe.scheduler.set_timesteps(50, device=device)
-            timesteps = pipe.scheduler.timesteps
-        except Exception:
-            pass
-    for epoch in range(int(epochs)):
-        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
-        for batch in pbar:
-            example = batch[0]
-            if example["type"] == "image":
-                img = example["image"].unsqueeze(0).to(device)
-                caption = [example["caption"]]
-                if not hasattr(pipe, "encode_prompt"):
-                    raise RuntimeError("Pipeline lacks encode_prompt - cannot encode prompts")
-                prompt_embeds, _ = pipe.encode_prompt(prompt=caption, negative_prompt=None, do_classifier_free_guidance=True, num_videos_per_prompt=1, device=device)
-                if not hasattr(pipe, "vae"):
-                    raise RuntimeError("Pipeline lacks VAE - required for latent conversion")
-                with torch.no_grad():
-                    latents = pipe.vae.encode(img.to(device)).latent_dist.sample() * pipe.vae.config.scaling_factor
-                noise = torch.randn_like(latents).to(device)
-                t = timesteps[torch.randint(0, len(timesteps), (1,)).item()].to(device) if timesteps is not None else torch.tensor(1, device=device)
-                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
-                out = peft_module(noisy_latents, t.expand(noisy_latents.shape[0]), encoder_hidden_states=prompt_embeds)
-                if hasattr(out, "sample"):
-                    noise_pred = out.sample
-                elif isinstance(out, tuple):
-                    noise_pred = out[0]
-                else:
-                    noise_pred = out
-                loss = loss_fn(noise_pred, noise)
-            else:
-                if not CHRONOEDIT_AVAILABLE:
-                    raise RuntimeError("ChronoEdit training requested but not installed in environment")
-                frames = example["frames"].unsqueeze(0).to(device)
-                frames_np = frames.squeeze(0).permute(0,2,3,1).cpu().numpy().tolist()
-                video_tensor = pipe.video_processor.preprocess(frames_np, height=frames.shape[-2], width=frames.shape[-1]).to(device)
-                latents_out = pipe.prepare_latents(video_tensor, batch_size=1, num_channels_latents=pipe.vae.config.z_dim, height=video_tensor.shape[-2], width=video_tensor.shape[-1], num_frames=frames.shape[1], dtype=video_tensor.dtype, device=device)
-                latents, condition = latents_out
-                noise = torch.randn_like(latents).to(device)
-                t = timesteps[torch.randint(0, len(timesteps), (1,)).item()].to(device)
-                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
-                latent_model_input = torch.cat([noisy_latents, condition], dim=1)
-                out = peft_module(hidden_states=latent_model_input, timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]))
-                noise_pred = out[0]
-                loss = loss_fn(noise_pred, noise)
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
-            global_step += 1
-            logs.append(f"step {global_step} loss {loss.item():.6f}")
-            pbar.set_postfix({"loss": f"{loss.item():.6f}"})
-            if max_train_steps and global_step >= max_train_steps:
-                break
-            if global_step % save_every_steps == 0:
-                out_sub = Path(output_dir) / f"lora_step_{global_step}"
-                out_sub.mkdir(parents=True, exist_ok=True)
-                try:
-                    peft_module.save_pretrained(str(out_sub))
-                except Exception:
-                    torch.save({k: v.cpu() for k,v in peft_module.state_dict().items()}, str(out_sub / "adapter_state_dict.pt"))
-        if max_train_steps and global_step >= max_train_steps:
-            break
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
-    try:
-        peft_module.save_pretrained(output_dir)
-    except Exception:
-        torch.save({k: v.cpu() for k,v in peft_module.state_dict().items()}, str(Path(output_dir) / "adapter_state_dict.pt"))
-    return output_dir, logs
-def test_generation_load_and_run_s2(base_model_id: str, adapter_dir: str, adapter_target: str, prompt: str, use_4bit: bool = False):
-    bnb_conf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4") if use_4bit and BNB_AVAILABLE else None
-    pipe = load_pipeline_auto_s2(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
-    try:
-        if adapter_target == "unet" and hasattr(pipe, "unet"):
-            pipe.unet.load_adapter(adapter_dir)
-        elif adapter_target == "transformer" and hasattr(pipe, "transformer"):
-            pipe.transformer.load_adapter(adapter_dir)
-        elif adapter_target == "text_encoder" and hasattr(pipe, "text_encoder"):
-            pipe.text_encoder.load_adapter(adapter_dir)
-    except Exception as e:
-        print(f"Adapter load warning: {e}")
-    pipe.to(DEVICE)
-    out = pipe(prompt=prompt, num_inference_steps=8)
-    if hasattr(out, "images"):
-        return out.images[0]
-    elif hasattr(out, "frames"):
-        frames = out.frames[0]
-        return Image.fromarray((frames[-1] * 255).clip(0,255).astype("uint8"))
-    raise RuntimeError("No images/frames returned")
-def upload_adapter_s2(local_dir: str, repo_id: str) -> str:
-    token = os.environ.get("HF_TOKEN")
-    if token is None:
-        raise RuntimeError("HF_TOKEN not set in environment for upload")
-    create_repo(repo_id, exist_ok=True)
-    upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model", token=token)
-    return f"https://huggingface.co/{repo_id}"
-def boost_info_text_s2(use_4bit: bool, enable_xformers: bool, mixed_precision: str, device_type: str):
-    lines = [f"Device: {device_type.upper()}"]
-    lines.append("4-bit QLoRA enabled: ~4x memory saving." if use_4bit and BNB_AVAILABLE else "QLoRA disabled.")
-    lines.append("xFormers/FlashAttention: memory-efficient attention enabled." if enable_xformers and XFORMERS_AVAILABLE else "xFormers disabled.")
-    lines.append(f"Mixed precision: {mixed_precision}" if mixed_precision else "Mixed precision: default.")
-    return "\n".join(lines)
-def run_all_ui_s2(base_model_id: str, dataset_source: str, csv_name: str, task_type: str, adapter_target_override: str, lora_r: int, lora_alpha: int, epochs: int, batch_size: int, lr: float, max_train_steps: int, output_dir: str, upload_repo: str, use_4bit: bool, enable_xformers: bool, use_adalora: bool, grad_accum: int, mixed_precision: str, save_every_steps: int):
-    adapter_target = adapter_target_override if adapter_target_override else infer_target_for_task_s2(task_type, base_model_id)
-    try:
-        out_dir, logs = train_lora_accelerate_s2(base_model_id, dataset_source, csv_name, task_type, adapter_target, output_dir, epochs=epochs, lr=lr, max_train_steps=(max_train_steps if max_train_steps>0 else None), lora_r=lora_r, lora_alpha=lora_alpha, use_4bit=use_4bit, enable_xformers=enable_xformers, use_adalora=use_adalora, gradient_accumulation_steps=grad_accum, mixed_precision=(mixed_precision if mixed_precision != "none" else None), save_every_steps=save_every_steps)
-    except Exception as e:
-        return f"Training failed: {e}", None, None
-    link = None
-    if upload_repo:
-        try:
-            link = upload_adapter_s2(out_dir, upload_repo)
-        except Exception as e:
-            link = f"Upload failed: {e}"
-    try:
-        ds = MediaTextDataset_s2(dataset_source, csv_name=csv_name, max_frames=5)
-        test_prompt = ds.df.iloc[0]["text"] if len(ds.df) > 0 else "A cat on a skateboard"
-    except Exception:
-        test_prompt = "A cat on a skateboard"
-    test_img = None
-    try:
-        test_img = test_generation_load_and_run_s2(base_model_id, out_dir, adapter_target, test_prompt, use_4bit=use_4bit)
-    except Exception as e:
-        print(f"Test gen failed: {e}")
-    return "\n".join(logs[-200:]), test_img, link
-def build_ui_s2():
-    with gr.Blocks() as demo_s2:
-        gr.Markdown("# Universal LoRA Trainer — Quantization & Speedups (single-file)")
-        with gr.Row():
-            base_model = gr.Textbox(label="Base model id (Diffusers / ChronoEdit / Qwen)", value="runwayml/stable-diffusion-v1-5")
-            dataset_source = gr.Textbox(label="Dataset folder or HF repo (username/repo)", value="./dataset")
-            csv_name = gr.Textbox(label="CSV/Parquet filename", value="dataset.csv")
-            task_type = gr.Dropdown(label="Task type", choices=["text-image", "text-video", "prompt-lora"], value="text-image")
-            adapter_target_override = gr.Textbox(label="Adapter target override (leave blank for auto)", value="")
-            lora_r = gr.Slider(1, 64, value=8, step=1, label="LoRA rank (r)")
-            lora_alpha = gr.Slider(1, 128, value=16, step=1, label="LoRA alpha")
-            epochs = gr.Number(label="Epochs", value=1)
-            batch_size = gr.Number(label="Batch size (per device)", value=1)
-            lr = gr.Number(label="Learning rate", value=1e-4)
-            max_train_steps = gr.Number(label="Max train steps (0 = unlimited)", value=0)
-            save_every_steps = gr.Number(label="Save every steps", value=200)
-            output_dir = gr.Textbox(label="Local output dir for adapter", value="./adapter_out")
-            upload_repo = gr.Textbox(label="Upload adapter to HF repo (optional)", value="")
-        with gr.Row():
-            use_4bit = gr.Checkbox(label="Enable 4-bit QLoRA (bitsandbytes)", value=False)
-            enable_xformers = gr.Checkbox(label="Enable xFormers / memory efficient attention", value=False)
-            use_adalora = gr.Checkbox(label="Use AdaLoRA (if available in peft)", value=False)
-            grad_accum = gr.Number(label="Gradient accumulation steps", value=1)
-            mixed_precision = gr.Radio(choices=["none", "fp16", "bf16"], value=("fp16" if torch.cuda.is_available() else "none"), label="Mixed precision")
-        boost_info = gr.Textbox(label="Expected boost / notes", value="", lines=6)
-        start_btn = gr.Button("Start Training")
-        with gr.Row():
-            logs = gr.Textbox(label="Training logs (tail)", lines=18)
-            sample_image = gr.Image(label="Sample generated frame after training")
-        upload_link = gr.Textbox(label="Upload link / status")
-        def on_start(base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit_val, enable_xformers_val, use_adalora_val, grad_accum_val, mixed_precision_val, save_every_steps):
-            boost_text = boost_info_text_s2(use_4bit_val, enable_xformers_val, mixed_precision_val, "gpu" if torch.cuda.is_available() else "cpu")
-            logs_out, sample, link = run_all_ui_s2(base_model, dataset_source, csv_name, task_type, adapter_target_override, int(lora_r), int(lora_alpha), int(epochs), int(batch_size), float(lr), int(max_train_steps), output_dir, upload_repo, use_4bit_val, enable_xformers_val, use_adalora_val, int(grad_accum_val), mixed_precision_val, int(save_every_steps))
-            return boost_text + "\n\n" + logs_out, sample, link
-        start_btn.click(on_start, inputs=[base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit, enable_xformers, use_adalora, grad_accum, mixed_precision, save_every_steps], outputs=[boost_info, sample_image, upload_link])
-    return demo_s2
-def run_all_ui_s3(base_model, dataset_src, csv_name, short_col, long_col, batch_size, num_workers, r, a, ep, lr, max_rec, repo_):
-    gen = train_lora_stream_s3(base_model, dataset_src, csv_name, [short_col, long_col], epochs=int(ep), lr=float(lr), r=int(r), alpha=int(a), batch_size=int(batch_size), num_workers=int(num_workers), max_train_records=int(max_rec))
-    for item in gen:
-        yield item
-    HF_TOKEN = os.environ.get("HF_TOKEN")
-    if not repo_ or not HF_TOKEN:
-        raise ValueError("HF repo ID and HF_TOKEN required for upload.")
-    repo_ = repo_.strip()
-    create_repo(repo_, repo_type="model", exist_ok=True, token=HF_TOKEN)
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        lora_module.save_pretrained(tmp_dir)
-        upload_folder(folder_path=tmp_dir, repo_id=repo_, repo_type="model", token=HF_TOKEN)
-    link = f"https://huggingface.co/{repo_}"
-    yield "\n".join(logs) + f"\n[INFO] ✅ Uploaded successfully: {link}\n", link
-def run_ui_s3_final():
-    with gr.Blocks() as demo_s3_final:
-        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer & Inference")
-        with gr.Row():
-            base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
-            csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
-            short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
-            long_col = gr.Textbox(label="Long prompt column", value="long_col")
-            repo = gr.Textbox(label="HF repo to upload LoRA", value="rahul7star/gemma-3-270m-ccebc0")
-        with gr.Row():
-            batch_size = gr.Number(value=1, label="Batch size")
-            num_workers = gr.Number(value=0, label="DataLoader num_workers")
-            r = gr.Number(value=8, label="LoRA rank")
-            a = gr.Number(value=16, label="LoRA alpha")
-            ep = gr.Number(value=1, label="Epochs")
-            lr = gr.Number(value=1e-4, label="Learning rate")
-            max_records = gr.Number(value=1000, label="Max training records")
-        logs = gr.Textbox(label="Logs (streaming)", lines=25)
-        btn = gr.Button("🚀 Start Training")
-        btn.click(fn=run_all_ui_s3,
-                  inputs=[base_model, dataset, csvname, short_col, long_col, batch_size, num_workers, r, a, ep, lr, max_records, repo],
-                  outputs=[logs],
-                  queue=True)
-        with gr.Tab("Inference (CPU)"):
-            inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
-            short_prompt = gr.Textbox(label="Short prompt")
-            long_prompt_out = gr.Textbox(label="Generated long prompt", lines=5)
-            inf_btn = gr.Button("📝 Generate Long Prompt")
-            inf_btn.click(fn=generate_long_prompt_cpu_s3,
-                          inputs=[inf_base_model, inf_lora_repo, short_prompt],
-                          outputs=[long_prompt_out])
-        with gr.Tab("Code Explain"):
-            explain_md = gr.Markdown("""
-            ### Universal LoRA Trainer & Inference - Code Explanation
-            #### 1. CORE MECHANISMS
             * **PEFT/LoRA**: Parameter-Efficient Fine-Tuning. Only low-rank matrices ($A$ and $B$) are trained for low-rank updates ($W' = W + B A$). This drastically reduces trainable parameters.
             * **QLoRA (4-bit)**: Loads the base model weights in 4-bit precision (NF4 with double quantization) using `bitsandbytes`, massively reducing VRAM usage while training LoRA adapters.
             * **Accelerator**: Manages device placement (CPU/GPU), mixed precision (`fp16`/`bf16`), and gradient accumulation for stable large-batch training simulation.
@@ -2429,27 +1855,37 @@ def run_ui_s3_final():
             * **Gradient Accumulation**: Simulates larger batch sizes by accumulating gradients over several forward/backward passes before an optimization step.
             * **Gradient Clipping**: Limits the maximum norm of the gradients (`max_grad_norm`) to prevent exploding gradients during training.
             * **Memory Optimization**: Optional use of `xFormers` (FlashAttention or memory-efficient attention) to reduce memory footprint and speed up training on compatible GPUs.
-            #### 2. DATA PROCESSING & AUGMENTATION
             * **Streaming Datasets**: Uses `datasets` streaming mode to handle very large datasets without loading all into RAM.
             * **Data Cleaning**: Removes HTML tags, normalizes whitespace, redacts PII, and removes URLs/emails.
             * **Advanced Filtering**: Includes optional filters for text length, word repetition, language detection, and basic toxicity detection (via `unitary/toxic-bert`).
             * **Data Augmentation**: Supports **Back-Translation (BT)** for introducing paraphrasing variations and **Counterfactual Data Augmentation (CDA)** for controlled bias testing (e.g., swapping gendered pronouns).
             * **Synthetic Data Generation**: Uses a specified LLM to generate new training examples based on an initial prompt template.
             * **Deduplication**: Implements both **Exact** and **Semantic (MinHash LSH)** deduplication to prevent data contamination during iterative fine-tuning.
-            #### 3. TRAINING MODES
             * **SFT (Supervised Fine-Tuning)**: Standard fine-tuning, supports **Conversation** and **Reasoning/Tool Use (CoT)** formatting styles.
             * **DPO (Direct Preference Optimization)**: Trains directly on preference pairs (chosen vs. rejected), using the `trl` library.
             * **Task-Specific Heads**: Supports **Sequence Classification**, **Token Classification (NER)**, and **Question Answering** by loading appropriate model heads (`AutoModelFor...`).
             * **Seq2Seq**: For translation/summarization tasks, using `Seq2SeqTrainer`.
             * **Diffusion (Text-to-Image/DreamBooth)**: Fine-tunes the UNet (and optionally Text Encoder) using LoRA for image generation tasks, with custom image/video data handling.
-            #### 4. MODEL INITIALIZATION
             * **Model From Scratch**: Allows initializing a model (e.g., Llama, Mistral) from a config rather than a pre-trained checkpoint, with optional auto-configuration based on expected training scale.
             * **Multi-Adapter Merging**: Advanced feature to combine multiple existing LoRA adapters into a single, new adapter using weighted averaging (`slerp`, `linear`, etc.).
-            #### 5. OUTPUT & DEPLOYMENT
             * **Hugging Face Hub Integration**: All trained artifacts (full model/LoRA adapter) are automatically pushed to a specified repository on the HF Hub using the provided token.
             * **Model Card Generation**: Automatically generates a `README.md` detailing training parameters and model provenance.
             * **Inference Tabs**: Separate UI for testing the trained LoRA adapter on CPU (for Gemma/LoRA) or various pipeline modes on GPU.
             """)
-    return demo_s3

 if __name__ == "__main__":
-    run_ui().launch(debug=True)
@@ -49,7 +49,7 @@ from transformers import (
     PhiConfig, PhiForCausalLM, Qwen2Config, Qwen2ForCausalLM,
     DataCollatorForLanguageModeling, DefaultDataCollator, Adafactor
 )
+from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training, AdaLoraConfig
 from trl import SFTTrainer, DPOTrainer
 from diffusers import (
     UNet2DConditionModel, DDPMScheduler, AutoencoderKL, DiffusionPipeline,
 
@@ -157,36 +157,44 @@ class DebiasingSFTTrainer(SFTTrainer):
             break
         return (loss, outputs) if return_outputs else loss

+def _deduplication_generator(dataset, text_col, method, threshold, num_perm):
+    if method == 'Exacta':
+        seen_texts = set()
+        for example in dataset:
+            text = example.get(text_col, "")
+            if text and isinstance(text, str):
+                if text not in seen_texts:
+                    seen_texts.add(text)
                     yield example
+            else:
+                yield example
+    elif method == 'Semántica (MinHash)':
+        lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
+        for i, example in enumerate(dataset):
+            text = example.get(text_col, "")
+            if text and isinstance(text, str) and text.strip():
+                m = MinHash(num_perm=num_perm)
+                for d in text.split():
+                    m.update(d.encode('utf8'))
+                if not lsh.query(m):
+                    lsh.insert(f"key_{i}", m)
                     yield example
+            else:
+                yield example
+    else:
+        yield from dataset
+
+def _create_deduplicated_iterable_dataset(dataset, text_col, method, threshold=0.85, num_perm=128):
+    return IterableDataset.from_generator(
+        _deduplication_generator,
+        gen_kwargs={
+            "dataset": dataset,
+            "text_col": text_col,
+            "method": method,
+            "threshold": threshold,
+            "num_perm": num_perm,
+        }
+    )

 @spaces.GPU()
 def hf_login(token):
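The new version passes a module-level generator plus `gen_kwargs` to `IterableDataset.from_generator` instead of a closure over `lsh_state`/`seen_texts_state`. The practical difference: `datasets` re-invokes the generator function on every pass over the stream (and can pickle it for caching and multiprocessing), so the dedup state is rebuilt fresh each time rather than carried over from a previous, possibly exhausted, closure. A minimal sketch of the same pattern, assuming only the `datasets` package; `dedup_gen`, the toy rows, and the column name are illustrative:

```python
from datasets import IterableDataset

def dedup_gen(rows, text_col):
    seen = set()  # fresh state every time the generator is (re-)invoked
    for row in rows:
        if row[text_col] not in seen:
            seen.add(row[text_col])
            yield row

rows = [{"text": "a"}, {"text": "a"}, {"text": "b"}]  # illustrative toy data
ds = IterableDataset.from_generator(dedup_gen, gen_kwargs={"rows": rows, "text_col": "text"})
print(list(ds))  # [{'text': 'a'}, {'text': 'b'}]
print(list(ds))  # a second pass re-runs dedup_gen with a fresh `seen` set
```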
 
@@ -1834,594 +1842,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
         inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_temperature, inf_top_p, inf_max_new_tokens],
         outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
     )
+    with gr.Tab("5. Explicación del Código y Mecanismos Avanzados"):
+        gr.Markdown("""
+        ### 🧠 Explicación del Código y Mecanismos Avanzados
+        """)
+        gr.Markdown("#### 1. CORE MECHANISMS")
+        gr.Markdown("""
         * **PEFT/LoRA**: Parameter-Efficient Fine-Tuning. Only low-rank matrices ($A$ and $B$) are trained for low-rank updates ($W' = W + B A$). This drastically reduces trainable parameters.
         * **QLoRA (4-bit)**: Loads the base model weights in 4-bit precision (NF4 with double quantization) using `bitsandbytes`, massively reducing VRAM usage while training LoRA adapters.
         * **Accelerator**: Manages device placement (CPU/GPU), mixed precision (`fp16`/`bf16`), and gradient accumulation for stable large-batch training simulation.
         * **Gradient Accumulation**: Simulates larger batch sizes by accumulating gradients over several forward/backward passes before an optimization step.
         * **Gradient Clipping**: Limits the maximum norm of the gradients (`max_grad_norm`) to prevent exploding gradients during training.
         * **Memory Optimization**: Optional use of `xFormers` (FlashAttention or memory-efficient attention) to reduce memory footprint and speed up training on compatible GPUs.
+        """)
+        gr.Markdown("#### 2. DATA PROCESSING & AUGMENTATION")
+        gr.Markdown("""
         * **Streaming Datasets**: Uses `datasets` streaming mode to handle very large datasets without loading all into RAM.
         * **Data Cleaning**: Removes HTML tags, normalizes whitespace, redacts PII, and removes URLs/emails.
         * **Advanced Filtering**: Includes optional filters for text length, word repetition, language detection, and basic toxicity detection (via `unitary/toxic-bert`).
         * **Data Augmentation**: Supports **Back-Translation (BT)** for introducing paraphrasing variations and **Counterfactual Data Augmentation (CDA)** for controlled bias testing (e.g., swapping gendered pronouns).
         * **Synthetic Data Generation**: Uses a specified LLM to generate new training examples based on an initial prompt template.
         * **Deduplication**: Implements both **Exact** and **Semantic (MinHash LSH)** deduplication to prevent data contamination during iterative fine-tuning.
+        """)
+        gr.Markdown("#### 3. TRAINING MODES")
+        gr.Markdown("""
         * **SFT (Supervised Fine-Tuning)**: Standard fine-tuning, supports **Conversation** and **Reasoning/Tool Use (CoT)** formatting styles.
         * **DPO (Direct Preference Optimization)**: Trains directly on preference pairs (chosen vs. rejected), using the `trl` library.
         * **Task-Specific Heads**: Supports **Sequence Classification**, **Token Classification (NER)**, and **Question Answering** by loading appropriate model heads (`AutoModelFor...`).
         * **Seq2Seq**: For translation/summarization tasks, using `Seq2SeqTrainer`.
         * **Diffusion (Text-to-Image/DreamBooth)**: Fine-tunes the UNet (and optionally Text Encoder) using LoRA for image generation tasks, with custom image/video data handling.
+        """)
+        gr.Markdown("#### 4. MODEL INITIALIZATION")
+        gr.Markdown("""
         * **Model From Scratch**: Allows initializing a model (e.g., Llama, Mistral) from a config rather than a pre-trained checkpoint, with optional auto-configuration based on expected training scale.
         * **Multi-Adapter Merging**: Advanced feature to combine multiple existing LoRA adapters into a single, new adapter using weighted averaging (`slerp`, `linear`, etc.).
+        """)
+        gr.Markdown("#### 5. OUTPUT & DEPLOYMENT")
+        gr.Markdown("""
         * **Hugging Face Hub Integration**: All trained artifacts (full model/LoRA adapter) are automatically pushed to a specified repository on the HF Hub using the provided token.
         * **Model Card Generation**: Automatically generates a `README.md` detailing training parameters and model provenance.
         * **Inference Tabs**: Separate UI for testing the trained LoRA adapter on CPU (for Gemma/LoRA) or various pipeline modes on GPU.
         """)
+    gr.Markdown("### 💡 Hardware Fallback")
+    gr.Markdown(f"If CUDA/GPU is unavailable, the system defaults to CPU: **{device.upper()}**. Training and inference on CPU will be significantly slower, especially for large models or Diffusers.")
+
 if __name__ == "__main__":
+    demo.queue().launch(debug=True, share=True)