Update app_flash.py
Browse files- app_flash.py +21 -16
app_flash.py
CHANGED
|
@@ -70,7 +70,7 @@ def train_flashpack_model(
|
|
| 70 |
dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
|
| 71 |
model_name: str = "gpt2",
|
| 72 |
max_length: int = 32,
|
| 73 |
-
|
| 74 |
push_to_hub: bool = False,
|
| 75 |
hf_repo: str = "rahul7star/FlashPack",
|
| 76 |
) -> tuple:
|
|
@@ -79,14 +79,12 @@ def train_flashpack_model(
|
|
| 79 |
print("📦 Loading dataset...")
|
| 80 |
dataset = load_dataset(dataset_name, split="train")
|
| 81 |
|
| 82 |
-
#
|
| 83 |
-
|
| 84 |
-
|
|
|
|
| 85 |
|
| 86 |
-
|
| 87 |
-
print(f"⚡ Using subset: {len(dataset)} examples for quick training")
|
| 88 |
-
|
| 89 |
-
# 2️⃣ Setup tokenizer and encoder
|
| 90 |
tokenizer, embed_model, encode_fn = build_encoder(model_name=model_name, max_length=max_length)
|
| 91 |
|
| 92 |
# 3️⃣ Encode dataset (CPU-friendly)
|
|
@@ -95,28 +93,35 @@ def train_flashpack_model(
|
|
| 95 |
for i, item in enumerate(dataset):
|
| 96 |
short_list.append(encode_fn(item["short_prompt"]))
|
| 97 |
long_list.append(encode_fn(item["long_prompt"]))
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
gc.collect()
|
| 101 |
|
| 102 |
short_embeddings = torch.vstack(short_list)
|
| 103 |
long_embeddings = torch.vstack(long_list)
|
|
|
|
| 104 |
|
| 105 |
-
# 4️⃣ Initialize model
|
| 106 |
model = GemmaTrainer(
|
| 107 |
input_dim=short_embeddings.shape[1],
|
| 108 |
-
hidden_dim=min(512, short_embeddings.shape[1]),
|
| 109 |
output_dim=long_embeddings.shape[1],
|
| 110 |
).to(device)
|
| 111 |
|
| 112 |
-
# 5️⃣ Training loop
|
| 113 |
criterion = nn.MSELoss()
|
| 114 |
optimizer = optim.Adam(model.parameters(), lr=1e-3)
|
| 115 |
-
max_epochs = 50
|
| 116 |
tolerance = 1e-4
|
| 117 |
-
batch_size = 32
|
| 118 |
|
| 119 |
-
print("🚀 Training FlashPack mapper model (
|
| 120 |
n = short_embeddings.shape[0]
|
| 121 |
for epoch in range(max_epochs):
|
| 122 |
model.train()
|
|
|
|
| 70 |
dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
|
| 71 |
model_name: str = "gpt2",
|
| 72 |
max_length: int = 32,
|
| 73 |
+
max_encode: int = 2000, # maximum number of prompts to encode
|
| 74 |
push_to_hub: bool = False,
|
| 75 |
hf_repo: str = "rahul7star/FlashPack",
|
| 76 |
) -> tuple:
|
|
|
|
| 79 |
print("📦 Loading dataset...")
|
| 80 |
dataset = load_dataset(dataset_name, split="train")
|
| 81 |
|
| 82 |
+
# Limit dataset to max_encode prompts
|
| 83 |
+
limit = min(max_encode, len(dataset))
|
| 84 |
+
dataset = dataset.select(range(limit))
|
| 85 |
+
print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
|
| 86 |
|
| 87 |
+
# 2️⃣ Setup tokenizer & encoder
|
|
|
|
|
|
|
|
|
|
| 88 |
tokenizer, embed_model, encode_fn = build_encoder(model_name=model_name, max_length=max_length)
|
| 89 |
|
| 90 |
# 3️⃣ Encode dataset (CPU-friendly)
|
|
|
|
| 93 |
for i, item in enumerate(dataset):
|
| 94 |
short_list.append(encode_fn(item["short_prompt"]))
|
| 95 |
long_list.append(encode_fn(item["long_prompt"]))
|
| 96 |
+
|
| 97 |
+
# Exit early if we hit max_encode
|
| 98 |
+
if (i + 1) >= max_encode:
|
| 99 |
+
print(f"⚡ Reached max encode limit: {max_encode} prompts, stopping early.")
|
| 100 |
+
break
|
| 101 |
+
|
| 102 |
+
# Progress logging
|
| 103 |
+
if (i + 1) % 50 == 0:
|
| 104 |
+
print(f" → Encoded {i+1}/{limit} prompts")
|
| 105 |
gc.collect()
|
| 106 |
|
| 107 |
short_embeddings = torch.vstack(short_list)
|
| 108 |
long_embeddings = torch.vstack(long_list)
|
| 109 |
+
print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
|
| 110 |
|
| 111 |
+
# 4️⃣ Initialize and train model (same as before)
|
| 112 |
model = GemmaTrainer(
|
| 113 |
input_dim=short_embeddings.shape[1],
|
| 114 |
+
hidden_dim=min(512, short_embeddings.shape[1]),
|
| 115 |
output_dim=long_embeddings.shape[1],
|
| 116 |
).to(device)
|
| 117 |
|
|
|
|
| 118 |
criterion = nn.MSELoss()
|
| 119 |
optimizer = optim.Adam(model.parameters(), lr=1e-3)
|
| 120 |
+
max_epochs = 50
|
| 121 |
tolerance = 1e-4
|
| 122 |
+
batch_size = 32
|
| 123 |
|
| 124 |
+
print("🚀 Training FlashPack mapper model (CPU)...")
|
| 125 |
n = short_embeddings.shape[0]
|
| 126 |
for epoch in range(max_epochs):
|
| 127 |
model.train()
|