Update app.py

app.py CHANGED
@@ -49,7 +49,7 @@ from transformers import (
     PhiConfig, PhiForCausalLM, Qwen2Config, Qwen2ForCausalLM,
     DataCollatorForLanguageModeling, DefaultDataCollator, Adafactor
 )
-from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training, AdaLoraConfig
 from trl import SFTTrainer, DPOTrainer
 from diffusers import (
     UNet2DConditionModel, DDPMScheduler, AutoencoderKL, DiffusionPipeline,
@@ -157,36 +157,44 @@ class DebiasingSFTTrainer(SFTTrainer):
                 break
         return (loss, outputs) if return_outputs else loss
 
-
-
-
-
-
-
-
-
-        if text and isinstance(text, str):
-            if text not in seen_texts_state:
-                seen_texts_state.add(text)
-                yield example
-        else:
+def _deduplication_generator(dataset, text_col, method, threshold, num_perm):
+    if method == 'Exacta':
+        seen_texts = set()
+        for example in dataset:
+            text = example.get(text_col, "")
+            if text and isinstance(text, str):
+                if text not in seen_texts:
+                    seen_texts.add(text)
                     yield example
-
-
-
-
-
-
-
-
-
-
-
+            else:
+                yield example
+    elif method == 'Semántica (MinHash)':
+        lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
+        for i, example in enumerate(dataset):
+            text = example.get(text_col, "")
+            if text and isinstance(text, str) and text.strip():
+                m = MinHash(num_perm=num_perm)
+                for d in text.split():
+                    m.update(d.encode('utf8'))
+                if not lsh.query(m):
+                    lsh.insert(f"key_{i}", m)
                     yield example
-
-
-
-
+            else:
+                yield example
+    else:
+        yield from dataset
+
+def _create_deduplicated_iterable_dataset(dataset, text_col, method, threshold=0.85, num_perm=128):
+    return IterableDataset.from_generator(
+        _deduplication_generator,
+        gen_kwargs={
+            "dataset": dataset,
+            "text_col": text_col,
+            "method": method,
+            "threshold": threshold,
+            "num_perm": num_perm,
+        }
+    )
 
 @spaces.GPU()
 def hf_login(token):
@@ -1834,594 +1842,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
         inputs=[inf_task_mode, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in, inf_temperature, inf_top_p, inf_max_new_tokens],
         outputs=[inf_text_out, inf_model_id, inf_text_in, inf_context_in, inf_image_in, inf_audio_in]
     )
-
-
-
-
-
-
-def download_from_hf_s1(repo_id, filename, token=None):
-    token = token or os.environ.get("HF_TOKEN")
-    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset", token=token)
-class MediaTextDataset_s1(Dataset):
-    def __init__(self, source, csv_name="dataset.csv", text_columns=None, max_records=None):
-        self.is_hub = is_hub_repo_like_s1(source)
-        token = os.environ.get("HF_TOKEN")
-        if self.is_hub:
-            file_path = download_from_hf_s1(source, csv_name, token)
-        else:
-            file_path = Path(source) / csv_name
-        if not Path(file_path).exists():
-            alt = Path(str(file_path).replace(".csv", ".parquet"))
-            if alt.exists():
-                file_path = alt
-            else:
-                raise FileNotFoundError(f"Dataset file not found: {file_path}")
-        self.df = pd.read_parquet(file_path) if str(file_path).endswith(".parquet") else pd.read_csv(file_path)
-        if max_records:
-            self.df = self.df.head(max_records)
-        self.text_columns = text_columns or ["short_prompt", "long_prompt"]
-    def __len__(self):
-        return len(self.df)
-    def __getitem__(self, i):
-        rec = self.df.iloc[i]
-        out = {"text": {}}
-        for col in self.text_columns:
-            out["text"][col] = rec[col] if col in rec else ""
-        return out
-def load_pipeline_auto_s1(base_model, dtype=torch.float16):
-    if "gemma" in base_model.lower():
-        if not TRANSFORMERS_AVAILABLE:
-            raise RuntimeError("Transformers not installed for LLM support.")
-        tokenizer = AutoTokenizer.from_pretrained(base_model)
-        model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=dtype)
-        return {"model": model, "tokenizer": tokenizer}
-    else:
-        raise NotImplementedError("Only Gemma LLM supported in this script.")
-def find_target_modules_s1(model):
-    candidates = ["q_proj", "k_proj", "v_proj", "out_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
-    names = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
-    targets = [n.split(".")[-1] for n in names if any(c in n for c in candidates)]
-    if not targets:
-        targets = [n.split(".")[-1] for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
-    return targets
-def unwrap_batch_s1(batch, short_col, long_col):
-    if isinstance(batch, (list, tuple)):
-        ex = batch[0]
-        if "text" in ex:
-            return ex
-        if "short" in ex and "long" in ex:
-            return {"text": {short_col: ex.get("short",""), long_col: ex.get("long","")}}
-        return {"text": ex}
-    if isinstance(batch, dict):
-        first_elem = {}
-        is_batched = any(isinstance(v, (list, tuple, np.ndarray, torch.Tensor)) for v in batch.values())
-        if is_batched:
-            for k, v in batch.items():
-                try: first = v[0]
-                except Exception: first = v
-                first_elem[k] = first
-            if "text" in first_elem:
-                t = first_elem["text"]
-                if isinstance(t, (list, tuple)) and len(t) > 0:
-                    return {"text": t[0] if isinstance(t[0], dict) else {short_col: t[0], long_col: ""}}
-                if isinstance(t, dict): return {"text": t}
-                return {"text": {short_col: str(t), long_col: ""}}
-            if ("short" in first_elem and "long" in first_elem) or (short_col in first_elem and long_col in first_elem):
-                s = first_elem.get(short_col, first_elem.get("short", ""))
-                l = first_elem.get(long_col, first_elem.get("long", ""))
-                return {"text": {short_col: str(s), long_col: str(l)}}
-            return {"text": {short_col: str(first_elem)}}
-        if "text" in batch and isinstance(batch["text"], dict):
-            return {"text": batch["text"]}
-        s = batch.get(short_col, batch.get("short", ""))
-        l = batch.get(long_col, batch.get("long", ""))
-        return {"text": {short_col: str(s), long_col: str(l)}}
-    return {"text": {short_col: str(batch), long_col: ""}}
-def train_lora_stream_s1(base_model, dataset_src, csv_name, text_cols, output_dir,
-                         epochs=1, lr=1e-4, r=8, alpha=16, batch_size=1, num_workers=0,
-                         max_train_records=None):
-    accelerator = accelerate.Accelerator()
-    pipe = load_pipeline_auto_s1(base_model)
-    model_obj = pipe["model"]
-    tokenizer = pipe["tokenizer"]
-    model_obj.train()
-    target_modules = find_target_modules_s1(model_obj)
-    lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
-    lora_module = get_peft_model(model_obj, lcfg)
-    dataset = MediaTextDataset_s1(dataset_src, csv_name, text_columns=text_cols, max_records=max_train_records)
-    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
-    optimizer = torch.optim.AdamW(lora_module.parameters(), lr=lr)
-    lora_module, optimizer, loader = accelerator.prepare(lora_module, optimizer, loader)
-    total_steps = max(1, epochs * len(loader))
-    step_counter = 0
-    logs = []
-    yield "[DEBUG] Starting training loop...\n", 0.0
-    for ep in range(epochs):
-        yield f"[DEBUG] Epoch {ep+1}/{epochs}\n", step_counter / total_steps
-        for i, batch in enumerate(loader):
-            ex = unwrap_batch_s1(batch, text_cols[0], text_cols[1])
-            texts = ex.get("text", {})
-            short_text = str(texts.get(text_cols[0], "") or "")
-            long_text = str(texts.get(text_cols[1], "") or "")
-            enc = tokenizer(short_text, text_pair=long_text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
-            enc = {k: v.to(accelerator.device) for k, v in enc.items()}
-            enc["labels"] = enc["input_ids"].clone()
-            outputs = lora_module(**enc)
-            forward_loss = getattr(outputs, "loss", None)
-            if forward_loss is None:
-                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
-                forward_loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), enc["labels"].view(-1), ignore_index=tokenizer.pad_token_id)
-            logs.append(f"[DEBUG] Step {step_counter}, forward_loss: {forward_loss.item():.6f}")
-            optimizer.zero_grad()
-            accelerator.backward(forward_loss)
-            optimizer.step()
-            step_counter += 1
-            yield "\n".join(logs[-10:]), step_counter / total_steps
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
-    lora_module.save_pretrained(output_dir)
-    yield f"[INFO] ✅ LoRA saved to {output_dir}\n", 1.0
-def upload_adapter_s1(local, repo_id):
-    token = os.environ.get("HF_TOKEN")
-    if not token:
-        raise RuntimeError("HF_TOKEN missing")
-    create_repo(repo_id, exist_ok=True)
-    upload_folder(local, repo_id=repo_id, repo_type="model", token=token)
-    return f"https://huggingface.co/{repo_id}"
-def run_ui_s1():
-    with gr.Blocks() as demo_s1:
-        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer (Gemma LLM)")
-        with gr.Row():
-            base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
-            csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
-            short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
-            long_col = gr.Textbox(label="Long prompt column", value="long_prompt")
-            out = gr.Textbox(label="Output dir", value="./adapter_out")
-            repo = gr.Textbox(label="Upload HF repo (optional)", value="rahul7star/gemma-3-270m-ccebc0")
-        with gr.Row():
-            batch_size = gr.Number(value=1, label="Batch size")
-            num_workers = gr.Number(value=0, label="DataLoader num_workers")
-            r = gr.Number(value=8, label="LoRA rank")
-            a = gr.Number(value=16, label="LoRA alpha")
-            ep = gr.Number(value=1, label="Epochs")
-            lr = gr.Number(value=1e-4, label="Learning rate")
-            max_records = gr.Number(value=1000, label="Max training records")
-        logs = gr.Textbox(label="Logs (streaming)", lines=25)
-        def launch(bm, ds, csv, sc, lc, out_dir, batch, num_w, r_, a_, ep_, lr_, max_rec, repo_):
-            gen = train_lora_stream_s1(bm, ds, csv, [sc, lc], out_dir, epochs=int(ep_), lr=float(lr_), r=int(r_), alpha=int(a_), batch_size=int(batch), num_workers=int(num_w), max_train_records=int(max_rec))
-            for item in gen:
-                if isinstance(item, tuple):
-                    text = item[0]
-                else:
-                    text = item
-                yield text
-            if repo_:
-                link = upload_adapter_s1(out_dir, repo_)
-                yield f"[INFO] Uploaded to {link}\n"
-        btn = gr.Button("🚀 Start Training")
-        btn.click(fn=launch, inputs=[base_model, dataset, csvname, short_col, long_col, out, batch_size, num_workers, r, a, ep, lr, max_records, repo], outputs=[logs], queue=True)
-    return demo_s1
-CHRONOEDIT_AVAILABLE = False
-try:
-    from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
-    CHRONOEDIT_AVAILABLE = True
-except Exception:
-    pass
-QWENEDIT_AVAILABLE = False
-try:
-    from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPipeline
-    QWENEDIT_AVAILABLE = True
-except Exception:
-    pass
-BNB_AVAILABLE = False
-try:
-    from transformers import BitsAndBytesConfig
-    BNB_AVAILABLE = True
-except Exception:
-    BitsAndBytesConfig = None
-XFORMERS_AVAILABLE = False
-try:
-    import xformers
-    XFORMERS_AVAILABLE = True
-except Exception:
-    pass
-ADALORA_AVAILABLE = False
-try:
-    from peft import AdaLoraConfig
-    ADALORA_AVAILABLE = True
-except Exception:
-    AdaLoraConfig = None
-IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"}
-VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv"}
-def is_hub_repo_like_s2(s: str) -> bool:
-    return "/" in s and not Path(s).exists()
-def download_from_hf_s2(repo_id: str, filename: str, token: str = None, repo_type: str = "dataset") -> str:
-    token = token or os.environ.get("HF_TOKEN")
-    return hf_hub_download(repo_id=repo_id, filename=filename, use_auth_token=token, repo_type=repo_type)
-def try_list_repo_files_s2(repo_id: str, repo_type: str = "dataset", token: str = None):
-    token = token or os.environ.get("HF_TOKEN")
-    try:
-        return list_repo_files(repo_id, token=token, repo_type=repo_type)
-    except Exception:
-        return []
-def find_target_modules_s2(model, candidates=("q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_k", "to_v", "proj_out", "to_out")):
-    names = [n for n, _ in model.named_modules()]
-    selected = set()
-    for cand in candidates:
-        for n in names:
-            if cand in n:
-                selected.add(n.split(".")[-1])
-    if not selected:
-        return ["to_q", "to_k", "to_v", "to_out"]
-    return list(selected)
-class MediaTextDataset_s2(Dataset):
-    def __init__(self, dataset_source: str, csv_name: str = "dataset.csv", max_frames: int = 5, image_size=(512,512), video_frame_size=(128,256), hub_token: str = None):
-        self.source = dataset_source
-        self.is_hub = is_hub_repo_like_s2(dataset_source)
-        self.max_frames = max_frames
-        self.image_size = image_size
-        self.video_frame_size = video_frame_size
-        self.hub_token = hub_token or os.environ.get("HF_TOKEN")
-        if self.is_hub:
-            try:
-                csv_local = download_from_hf_s2(self.source, csv_name, token=self.hub_token, repo_type="dataset")
-            except Exception:
-                alt = csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else csv_name + ".parquet"
-                csv_local = download_from_hf_s2(self.source, alt, token=self.hub_token, repo_type="dataset")
-            if str(csv_local).endswith(".parquet"):
-                df = pd.read_parquet(csv_local)
-            else:
-                df = pd.read_csv(csv_local)
-            self.df = df
-            self.root = None
-        else:
-            root = Path(dataset_source)
-            csv_path = root / csv_name
-            parquet_path = root / csv_name.replace(".csv", ".parquet") if csv_name.endswith(".csv") else root / (csv_name + ".parquet")
-            if csv_path.exists():
-                self.df = pd.read_csv(csv_path)
-            elif parquet_path.exists():
-                self.df = pd.read_parquet(parquet_path)
-            else:
-                p = root / csv_name
-                if p.exists():
-                    if p.suffix.lower() == ".parquet":
-                        self.df = pd.read_parquet(p)
-                    else:
-                        self.df = pd.read_csv(p)
-                else:
-                    raise FileNotFoundError(f"Can't find {csv_name} in {dataset_source}")
-            self.root = root
-        self.image_transform = T.Compose([T.ToPILImage(), T.Resize(image_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
-        self.video_transform = T.Compose([T.ToPILImage(), T.Resize(video_frame_size), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])
-    def __len__(self):
-        return len(self.df)
-    def _maybe_download_from_hub(self, file_name: str) -> str:
-        if self.root is not None:
-            p = self.root / file_name
-            if p.exists():
-                return str(p)
-        return download_from_hf_s2(self.source, file_name, token=self.hub_token, repo_type="dataset")
-    def _read_video_frames(self, path: str, num_frames: int):
-        video_frames, _, _ = torchvision.io.read_video(str(path), pts_unit='sec')
-        total = len(video_frames)
-        if total == 0:
-            C, H, W = 3, self.video_frame_size[0], self.video_frame_size[1]
-            return torch.zeros((num_frames, C, H, W), dtype=torch.float32)
-        if total < num_frames:
-            idxs = list(range(total)) + [total-1]*(num_frames-total)
-        else:
-            idxs = np.linspace(0, total-1, num_frames).round().astype(int).tolist()
-        frames = []
-        for i in idxs:
-            arr = video_frames[i].numpy() if hasattr(video_frames[i], "numpy") else np.array(video_frames[i])
-            frames.append(self.video_transform(arr))
-        frames = torch.stack(frames, dim=0)
-        return frames
-    def __getitem__(self, idx):
-        rec = self.df.iloc[idx]
-        file_name = rec["file_name"]
-        caption = rec["text"]
-        if self.is_hub:
-            local_path = self._maybe_download_from_hub(file_name)
-        else:
-            local_path = str(Path(self.root) / file_name)
-        p = Path(local_path)
-        suffix = p.suffix.lower()
-        if suffix in IMAGE_EXTS:
-            img = torchvision.io.read_image(local_path)
-            if isinstance(img, torch.Tensor):
-                img = img.permute(1,2,0).numpy()
-            return {'type': 'image', 'image': self.image_transform(img), 'caption': caption, 'file_name': file_name}
-        elif suffix in VIDEO_EXTS:
-            frames = self._read_video_frames(local_path, self.max_frames)
-            return {'type': 'video', 'frames': frames, 'caption': caption, 'file_name': file_name}
-        else:
-            raise RuntimeError(f"Unsupported media type: {local_path}")
-def load_pipeline_auto_s2(base_model_id: str, use_4bit: bool = False, bnb_config: object = None, torch_dtype=torch.float16):
-    low = base_model_id.lower()
-    is_chrono = "chrono" in low or "wan" in low or "video" in low
-    is_qwen = "qwen" in low or "qwenimage" in low
-    if is_chrono and CHRONOEDIT_AVAILABLE:
-        if use_4bit and bnb_config is not None:
-            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
-        else:
-            pipe = ChronoEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
-    elif is_qwen and QWENEDIT_AVAILABLE:
-        pipe = QwenImageEditPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
-    else:
-        if use_4bit and BNB_AVAILABLE and bnb_config is not None:
-            pipe = DiffusionPipeline.from_pretrained(base_model_id, quantization_config=bnb_config, torch_dtype=torch.float16)
-        else:
-            pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
-    return pipe
-def infer_target_for_task_s2(task_type: str, model_name: str) -> str:
-    low = model_name.lower()
-    if task_type == "prompt-lora" or "qwen" in low or "qwenedit" in low:
-        return "text_encoder"
-    if task_type == "text-video" or "chrono" in low or "wan" in low:
-        return "transformer"
-    return "unet"
-def attach_lora_s2(pipe, adapter_target: str, r: int = 8, alpha: int = 16, dropout: float = 0.0, use_adalora: bool = False):
-    if adapter_target == "unet":
-        target_module = pipe.unet
-        attr = "unet"
-    elif adapter_target == "transformer":
-        target_module = pipe.transformer
-        attr = "transformer"
-    elif adapter_target == "text_encoder":
-        target_module = pipe.text_encoder
-        attr = "text_encoder"
-    else:
-        raise RuntimeError("Unknown adapter_target")
-    target_modules = find_target_modules_s2(target_module)
-    if use_adalora and ADALORA_AVAILABLE:
-        lora_config = AdaLoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, init_r=4, lora_dropout=dropout)
-    else:
-        lora_config = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=dropout, bias="none", task_type="SEQ_2_SEQ_LM")
-    peft_model = get_peft_model(target_module, lora_config)
-    setattr(pipe, attr, peft_model)
-    return pipe, attr
-def train_lora_accelerate_s2(base_model_id: str, dataset_source: str, csv_name: str, task_type: str, adapter_target_override: str, output_dir: str, epochs: int = 1, batch_size: int = 1, lr: float = 1e-4, max_train_steps: int = None, lora_r: int = 8, lora_alpha: int = 16, use_4bit: bool = False, enable_xformers: bool = False, use_adalora: bool = False, gradient_accumulation_steps: int = 1, mixed_precision: str = None, save_every_steps: int = 200, max_frames: int = 5):
-    accelerator = accelerate.Accelerator(mixed_precision=mixed_precision or ("fp16" if torch.cuda.is_available() else "no"))
-    device = accelerator.device
-    bnb_conf = None
-    if use_4bit and BNB_AVAILABLE:
-        bnb_conf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
-    pipe = load_pipeline_auto_s2(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)
-    if enable_xformers:
-        try:
-            if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
-                pipe.enable_xformers_memory_efficient_attention()
-            elif hasattr(pipe, "enable_attention_slicing"):
-                pipe.enable_attention_slicing()
-        except Exception as e:
-            print(f"Could not enable xformers: {e}")
-    adapter_target = adapter_target_override if adapter_target_override else infer_target_for_task_s2(task_type, base_model_id)
-    pipe, attr = attach_lora_s2(pipe, adapter_target, r=lora_r, alpha=lora_alpha, dropout=0.0, use_adalora=use_adalora)
-    peft_module = getattr(pipe, attr)
-    dataset = MediaTextDataset_s2(dataset_source, csv_name=csv_name, max_frames=max_frames)
-    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x)
-    trainable_params = [p for n,p in peft_module.named_parameters() if p.requires_grad]
-    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
-    peft_module, optimizer, dataloader = accelerator.prepare(peft_module, optimizer, dataloader)
-    logs = []
-    global_step = 0
-    loss_fn = nn.MSELoss()
-    timesteps = None
-    if hasattr(pipe, "scheduler"):
-        try:
-            pipe.scheduler.set_timesteps(50, device=device)
-            timesteps = pipe.scheduler.timesteps
-        except Exception:
-            pass
-    for epoch in range(int(epochs)):
-        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
-        for batch in pbar:
-            example = batch[0]
-            if example["type"] == "image":
-                img = example["image"].unsqueeze(0).to(device)
-                caption = [example["caption"]]
-                if not hasattr(pipe, "encode_prompt"):
-                    raise RuntimeError("Pipeline lacks encode_prompt - cannot encode prompts")
-                prompt_embeds, _ = pipe.encode_prompt(prompt=caption, negative_prompt=None, do_classifier_free_guidance=True, num_videos_per_prompt=1, device=device)
-                if not hasattr(pipe, "vae"):
-                    raise RuntimeError("Pipeline lacks VAE - required for latent conversion")
-                with torch.no_grad():
-                    latents = pipe.vae.encode(img.to(device)).latent_dist.sample() * pipe.vae.config.scaling_factor
-                noise = torch.randn_like(latents).to(device)
-                t = timesteps[torch.randint(0, len(timesteps), (1,)).item()].to(device) if timesteps is not None else torch.tensor(1, device=device)
-                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
-                out = peft_module(noisy_latents, t.expand(noisy_latents.shape[0]), encoder_hidden_states=prompt_embeds)
-                if hasattr(out, "sample"):
-                    noise_pred = out.sample
-                elif isinstance(out, tuple):
-                    noise_pred = out[0]
-                else:
-                    noise_pred = out
-                loss = loss_fn(noise_pred, noise)
-            else:
-                if not CHRONOEDIT_AVAILABLE:
-                    raise RuntimeError("ChronoEdit training requested but not installed in environment")
-                frames = example["frames"].unsqueeze(0).to(device)
-                frames_np = frames.squeeze(0).permute(0,2,3,1).cpu().numpy().tolist()
-                video_tensor = pipe.video_processor.preprocess(frames_np, height=frames.shape[-2], width=frames.shape[-1]).to(device)
-                latents_out = pipe.prepare_latents(video_tensor, batch_size=1, num_channels_latents=pipe.vae.config.z_dim, height=video_tensor.shape[-2], width=video_tensor.shape[-1], num_frames=frames.shape[1], dtype=video_tensor.dtype, device=device)
-                latents, condition = latents_out
-                noise = torch.randn_like(latents).to(device)
-                t = timesteps[torch.randint(0, len(timesteps), (1,)).item()].to(device)
-                noisy_latents = pipe.scheduler.add_noise(latents, noise, t)
-                latent_model_input = torch.cat([noisy_latents, condition], dim=1)
-                out = peft_module(hidden_states=latent_model_input, timestep=t.unsqueeze(0).expand(latent_model_input.shape[0]))
-                noise_pred = out[0]
-                loss = loss_fn(noise_pred, noise)
-            accelerator.backward(loss)
-            optimizer.step()
-            optimizer.zero_grad()
-            global_step += 1
-            logs.append(f"step {global_step} loss {loss.item():.6f}")
-            pbar.set_postfix({"loss": f"{loss.item():.6f}"})
-            if max_train_steps and global_step >= max_train_steps:
-                break
-            if global_step % save_every_steps == 0:
-                out_sub = Path(output_dir) / f"lora_step_{global_step}"
-                out_sub.mkdir(parents=True, exist_ok=True)
-                try:
-                    peft_module.save_pretrained(str(out_sub))
-                except Exception:
-                    torch.save({k: v.cpu() for k,v in peft_module.state_dict().items()}, str(out_sub / "adapter_state_dict.pt"))
-        if max_train_steps and global_step >= max_train_steps:
-            break
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
-    try:
-        peft_module.save_pretrained(output_dir)
-    except Exception:
-        torch.save({k: v.cpu() for k,v in peft_module.state_dict().items()}, str(Path(output_dir) / "adapter_state_dict.pt"))
-    return output_dir, logs
-def test_generation_load_and_run_s2(base_model_id: str, adapter_dir: str, adapter_target: str, prompt: str, use_4bit: bool = False):
-    bnb_conf = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4") if use_4bit and BNB_AVAILABLE else None
-    pipe = load_pipeline_auto_s2(base_model_id, use_4bit=use_4bit, bnb_config=bnb_conf, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
-    try:
-        if adapter_target == "unet" and hasattr(pipe, "unet"):
-            pipe.unet.load_adapter(adapter_dir)
-        elif adapter_target == "transformer" and hasattr(pipe, "transformer"):
-            pipe.transformer.load_adapter(adapter_dir)
-        elif adapter_target == "text_encoder" and hasattr(pipe, "text_encoder"):
-            pipe.text_encoder.load_adapter(adapter_dir)
-    except Exception as e:
-        print(f"Adapter load warning: {e}")
-    pipe.to(DEVICE)
-    out = pipe(prompt=prompt, num_inference_steps=8)
-    if hasattr(out, "images"):
-        return out.images[0]
-    elif hasattr(out, "frames"):
-        frames = out.frames[0]
-        return Image.fromarray((frames[-1] * 255).clip(0,255).astype("uint8"))
-    raise RuntimeError("No images/frames returned")
-def upload_adapter_s2(local_dir: str, repo_id: str) -> str:
-    token = os.environ.get("HF_TOKEN")
-    if token is None:
-        raise RuntimeError("HF_TOKEN not set in environment for upload")
-    create_repo(repo_id, exist_ok=True)
-    upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type="model", token=token)
-    return f"https://huggingface.co/{repo_id}"
-def boost_info_text_s2(use_4bit: bool, enable_xformers: bool, mixed_precision: str, device_type: str):
-    lines = [f"Device: {device_type.upper()}"]
-    lines.append("4-bit QLoRA enabled: ~4x memory saving." if use_4bit and BNB_AVAILABLE else "QLoRA disabled.")
-    lines.append("xFormers/FlashAttention: memory-efficient attention enabled." if enable_xformers and XFORMERS_AVAILABLE else "xFormers disabled.")
-    lines.append(f"Mixed precision: {mixed_precision}" if mixed_precision else "Mixed precision: default.")
-    return "\n".join(lines)
-def run_all_ui_s2(base_model_id: str, dataset_source: str, csv_name: str, task_type: str, adapter_target_override: str, lora_r: int, lora_alpha: int, epochs: int, batch_size: int, lr: float, max_train_steps: int, output_dir: str, upload_repo: str, use_4bit: bool, enable_xformers: bool, use_adalora: bool, grad_accum: int, mixed_precision: str, save_every_steps: int):
-    adapter_target = adapter_target_override if adapter_target_override else infer_target_for_task_s2(task_type, base_model_id)
-    try:
-        out_dir, logs = train_lora_accelerate_s2(base_model_id, dataset_source, csv_name, task_type, adapter_target, output_dir, epochs=epochs, lr=lr, max_train_steps=(max_train_steps if max_train_steps>0 else None), lora_r=lora_r, lora_alpha=lora_alpha, use_4bit=use_4bit, enable_xformers=enable_xformers, use_adalora=use_adalora, gradient_accumulation_steps=grad_accum, mixed_precision=(mixed_precision if mixed_precision != "none" else None), save_every_steps=save_every_steps)
-    except Exception as e:
-        return f"Training failed: {e}", None, None
-    link = None
-    if upload_repo:
-        try:
-            link = upload_adapter_s2(out_dir, upload_repo)
-        except Exception as e:
-            link = f"Upload failed: {e}"
-    try:
-        ds = MediaTextDataset_s2(dataset_source, csv_name=csv_name, max_frames=5)
-        test_prompt = ds.df.iloc[0]["text"] if len(ds.df) > 0 else "A cat on a skateboard"
-    except Exception:
-        test_prompt = "A cat on a skateboard"
-    test_img = None
-    try:
-        test_img = test_generation_load_and_run_s2(base_model_id, out_dir, adapter_target, test_prompt, use_4bit=use_4bit)
-    except Exception as e:
-        print(f"Test gen failed: {e}")
-    return "\n".join(logs[-200:]), test_img, link
-def build_ui_s2():
-    with gr.Blocks() as demo_s2:
-        gr.Markdown("# Universal LoRA Trainer — Quantization & Speedups (single-file)")
-        with gr.Row():
-            base_model = gr.Textbox(label="Base model id (Diffusers / ChronoEdit / Qwen)", value="runwayml/stable-diffusion-v1-5")
-            dataset_source = gr.Textbox(label="Dataset folder or HF repo (username/repo)", value="./dataset")
-            csv_name = gr.Textbox(label="CSV/Parquet filename", value="dataset.csv")
-            task_type = gr.Dropdown(label="Task type", choices=["text-image", "text-video", "prompt-lora"], value="text-image")
-            adapter_target_override = gr.Textbox(label="Adapter target override (leave blank for auto)", value="")
-            lora_r = gr.Slider(1, 64, value=8, step=1, label="LoRA rank (r)")
-            lora_alpha = gr.Slider(1, 128, value=16, step=1, label="LoRA alpha")
-            epochs = gr.Number(label="Epochs", value=1)
-            batch_size = gr.Number(label="Batch size (per device)", value=1)
-            lr = gr.Number(label="Learning rate", value=1e-4)
-            max_train_steps = gr.Number(label="Max train steps (0 = unlimited)", value=0)
-            save_every_steps = gr.Number(label="Save every steps", value=200)
-            output_dir = gr.Textbox(label="Local output dir for adapter", value="./adapter_out")
-            upload_repo = gr.Textbox(label="Upload adapter to HF repo (optional)", value="")
-        with gr.Row():
-            use_4bit = gr.Checkbox(label="Enable 4-bit QLoRA (bitsandbytes)", value=False)
-            enable_xformers = gr.Checkbox(label="Enable xFormers / memory efficient attention", value=False)
-            use_adalora = gr.Checkbox(label="Use AdaLoRA (if available in peft)", value=False)
-            grad_accum = gr.Number(label="Gradient accumulation steps", value=1)
-            mixed_precision = gr.Radio(choices=["none", "fp16", "bf16"], value=("fp16" if torch.cuda.is_available() else "none"), label="Mixed precision")
-        boost_info = gr.Textbox(label="Expected boost / notes", value="", lines=6)
-        start_btn = gr.Button("Start Training")
-        with gr.Row():
-            logs = gr.Textbox(label="Training logs (tail)", lines=18)
-            sample_image = gr.Image(label="Sample generated frame after training")
-        upload_link = gr.Textbox(label="Upload link / status")
-        def on_start(base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit_val, enable_xformers_val, use_adalora_val, grad_accum_val, mixed_precision_val, save_every_steps):
-            boost_text = boost_info_text_s2(use_4bit_val, enable_xformers_val, mixed_precision_val, "gpu" if torch.cuda.is_available() else "cpu")
-            logs_out, sample, link = run_all_ui_s2(base_model, dataset_source, csv_name, task_type, adapter_target_override, int(lora_r), int(lora_alpha), int(epochs), int(batch_size), float(lr), int(max_train_steps), output_dir, upload_repo, use_4bit_val, enable_xformers_val, use_adalora_val, int(grad_accum_val), mixed_precision_val, int(save_every_steps))
-            return boost_text + "\n\n" + logs_out, sample, link
-        start_btn.click(on_start, inputs=[base_model, dataset_source, csv_name, task_type, adapter_target_override, lora_r, lora_alpha, epochs, batch_size, lr, max_train_steps, output_dir, upload_repo, use_4bit, enable_xformers, use_adalora, grad_accum, mixed_precision, save_every_steps], outputs=[boost_info, sample_image, upload_link])
-    return demo_s2
-def run_all_ui_s3(base_model, dataset_src, csv_name, short_col, long_col, batch_size, num_workers, r, a, ep, lr, max_rec, repo_):
-    gen = train_lora_stream_s3(base_model, dataset_src, csv_name, [short_col, long_col], epochs=int(ep), lr=float(lr), r=int(r), alpha=int(a), batch_size=int(batch_size), num_workers=int(num_workers), max_train_records=int(max_rec))
-    for item in gen:
-        yield item
-    HF_TOKEN = os.environ.get("HF_TOKEN")
-    if not repo_ or not HF_TOKEN:
-        raise ValueError("HF repo ID and HF_TOKEN required for upload.")
-    repo_ = repo_.strip()
-    create_repo(repo_, repo_type="model", exist_ok=True, token=HF_TOKEN)
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        lora_module.save_pretrained(tmp_dir)
-        upload_folder(folder_path=tmp_dir, repo_id=repo_, repo_type="model", token=HF_TOKEN)
-    link = f"https://huggingface.co/{repo_}"
-    yield "\n".join(logs) + f"\n[INFO] ✅ Uploaded successfully: {link}\n", link
-def run_ui_s3_final():
-    with gr.Blocks() as demo_s3_final:
-        gr.Markdown("# 🌐 Universal Dynamic LoRA Trainer & Inference")
-        with gr.Row():
-            base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            dataset = gr.Textbox(label="Dataset folder or HF repo", value="rahul7star/prompt-enhancer-dataset-01")
-            csvname = gr.Textbox(label="CSV/Parquet file", value="train-00000-of-00001.csv")
-            short_col = gr.Textbox(label="Short prompt column", value="short_prompt")
-            long_col = gr.Textbox(label="Long prompt column", value="long_col")
-            repo = gr.Textbox(label="HF repo to upload LoRA", value="rahul7star/gemma-3-270m-ccebc0")
-        with gr.Row():
-            batch_size = gr.Number(value=1, label="Batch size")
-            num_workers = gr.Number(value=0, label="DataLoader num_workers")
-            r = gr.Number(value=8, label="LoRA rank")
-            a = gr.Number(value=16, label="LoRA alpha")
-            ep = gr.Number(value=1, label="Epochs")
-            lr = gr.Number(value=1e-4, label="Learning rate")
-            max_records = gr.Number(value=1000, label="Max training records")
-        logs = gr.Textbox(label="Logs (streaming)", lines=25)
-        btn = gr.Button("🚀 Start Training")
-        btn.click(fn=run_all_ui_s3,
-                  inputs=[base_model, dataset, csvname, short_col, long_col, batch_size, num_workers, r, a, ep, lr, max_records, repo],
-                  outputs=[logs],
-                  queue=True)
-        with gr.Tab("Inference (CPU)"):
-            inf_base_model = gr.Textbox(label="Base model", value="google/gemma-3-4b-it")
-            inf_lora_repo = gr.Textbox(label="LoRA HF repo", value="rahul7star/gemma-3-270m-ccebc0")
-            short_prompt = gr.Textbox(label="Short prompt")
-            long_prompt_out = gr.Textbox(label="Generated long prompt", lines=5)
-            inf_btn = gr.Button("📝 Generate Long Prompt")
-            inf_btn.click(fn=generate_long_prompt_cpu_s3,
-                          inputs=[inf_base_model, inf_lora_repo, short_prompt],
-                          outputs=[long_prompt_out])
-        with gr.Tab("Code Explain"):
-            explain_md = gr.Markdown("""
-### Universal LoRA Trainer & Inference - Code Explanation
-#### 1. CORE MECHANISMS
+    with gr.Tab("5. Explicación del Código y Mecanismos Avanzados"):
+        gr.Markdown("""
+        ### 🧠 Explicación del Código y Mecanismos Avanzados
+        """)
+        gr.Markdown("#### 1. CORE MECHANISMS")
+        gr.Markdown("""
 * **PEFT/LoRA**: Parameter-Efficient Fine-Tuning. Only low-rank matrices ($A$ and $B$) are trained for low-rank updates ($W' = W + B A$). This drastically reduces trainable parameters.
 * **QLoRA (4-bit)**: Loads the base model weights in 4-bit precision (NF4 with double quantization) using `bitsandbytes`, massively reducing VRAM usage while training LoRA adapters.
 * **Accelerator**: Manages device placement (CPU/GPU), mixed precision (`fp16`/`bf16`), and gradient accumulation for stable large-batch training simulation.
@@ -2429,27 +1855,37 @@ def run_ui_s3_final():
 * **Gradient Accumulation**: Simulates larger batch sizes by accumulating gradients over several forward/backward passes before an optimization step.
 * **Gradient Clipping**: Limits the maximum norm of the gradients (`max_grad_norm`) to prevent exploding gradients during training.
 * **Memory Optimization**: Optional use of `xFormers` (FlashAttention or memory-efficient attention) to reduce memory footprint and speed up training on compatible GPUs.
-
+        """)
+        gr.Markdown("#### 2. DATA PROCESSING & AUGMENTATION")
+        gr.Markdown("""
 * **Streaming Datasets**: Uses `datasets` streaming mode to handle very large datasets without loading all into RAM.
 * **Data Cleaning**: Removes HTML tags, normalizes whitespace, redacts PII, and removes URLs/emails.
 * **Advanced Filtering**: Includes optional filters for text length, word repetition, language detection, and basic toxicity detection (via `unitary/toxic-bert`).
 * **Data Augmentation**: Supports **Back-Translation (BT)** for introducing paraphrasing variations and **Counterfactual Data Augmentation (CDA)** for controlled bias testing (e.g., swapping gendered pronouns).
 * **Synthetic Data Generation**: Uses a specified LLM to generate new training examples based on an initial prompt template.
 * **Deduplication**: Implements both **Exact** and **Semantic (MinHash LSH)** deduplication to prevent data contamination during iterative fine-tuning.
-
+        """)
+        gr.Markdown("#### 3. TRAINING MODES")
+        gr.Markdown("""
 * **SFT (Supervised Fine-Tuning)**: Standard fine-tuning, supports **Conversation** and **Reasoning/Tool Use (CoT)** formatting styles.
 * **DPO (Direct Preference Optimization)**: Trains directly on preference pairs (chosen vs. rejected), using the `trl` library.
 * **Task-Specific Heads**: Supports **Sequence Classification**, **Token Classification (NER)**, and **Question Answering** by loading appropriate model heads (`AutoModelFor...`).
 * **Seq2Seq**: For translation/summarization tasks, using `Seq2SeqTrainer`.
 * **Diffusion (Text-to-Image/DreamBooth)**: Fine-tunes the UNet (and optionally Text Encoder) using LoRA for image generation tasks, with custom image/video data handling.
-
+        """)
+        gr.Markdown("#### 4. MODEL INITIALIZATION")
+        gr.Markdown("""
 * **Model From Scratch**: Allows initializing a model (e.g., Llama, Mistral) from a config rather than a pre-trained checkpoint, with optional auto-configuration based on expected training scale.
 * **Multi-Adapter Merging**: Advanced feature to combine multiple existing LoRA adapters into a single, new adapter using weighted averaging (`slerp`, `linear`, etc.).
-
+        """)
+        gr.Markdown("#### 5. OUTPUT & DEPLOYMENT")
+        gr.Markdown("""
 * **Hugging Face Hub Integration**: All trained artifacts (full model/LoRA adapter) are automatically pushed to a specified repository on the HF Hub using the provided token.
 * **Model Card Generation**: Automatically generates a `README.md` detailing training parameters and model provenance.
 * **Inference Tabs**: Separate UI for testing the trained LoRA adapter on CPU (for Gemma/LoRA) or various pipeline modes on GPU.
 """)
-
+        gr.Markdown("### 💡 Hardware Fallback")
+        gr.Markdown(f"If CUDA/GPU is unavailable, the system defaults to CPU: **{device.upper()}**. Training and inference on CPU will be significantly slower, especially for large models or Diffusers.")
+
 if __name__ == "__main__":
-
+    demo.queue().launch(debug=True, share=True)
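
Reviewer note: the QLoRA bullet above is the one mechanism the surviving code never shows end to end; the removed `train_lora_accelerate_s2` built its `BitsAndBytesConfig` only for diffusion pipelines. A hedged end-to-end sketch for a causal LM, reusing the exact quantization arguments from that removed function; the model id is the app's own default and may be gated (an HF token and a CUDA device are assumed):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb = BitsAndBytesConfig(
    load_in_4bit=True,                  # NF4 4-bit base weights
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,     # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-4b-it",             # app default; assumed accessible
    quantization_config=bnb,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)  # cast norms, enable input grads
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16,
                                         target_modules=["q_proj", "v_proj"]))
model.print_trainable_parameters()      # only the LoRA factors are trainable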
|