rahul7star committed on
Commit 9aeedd9 · verified · 1 Parent(s): 7760858

Update app_flash.py

Files changed (1)
  1. app_flash.py +21 -16
app_flash.py CHANGED
@@ -70,7 +70,7 @@ def train_flashpack_model(
     dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
     model_name: str = "gpt2",
     max_length: int = 32,
-    subset_limit: int | None = 500,  # None = use full dataset
+    max_encode: int = 2000,  # maximum number of prompts to encode
     push_to_hub: bool = False,
     hf_repo: str = "rahul7star/FlashPack",
 ) -> tuple:
@@ -79,14 +79,12 @@ def train_flashpack_model(
     print("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train")
 
-    # Handle subset for quick CPU training
-    if subset_limit is None:
-        subset_limit = len(dataset)
+    # Limit dataset to max_encode prompts
+    limit = min(max_encode, len(dataset))
+    dataset = dataset.select(range(limit))
+    print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
 
-    dataset = dataset.select(range(min(subset_limit, len(dataset))))
-    print(f"⚡ Using subset: {len(dataset)} examples for quick training")
-
-    # 2️⃣ Setup tokenizer and encoder
+    # 2️⃣ Setup tokenizer & encoder
     tokenizer, embed_model, encode_fn = build_encoder(model_name=model_name, max_length=max_length)
 
     # 3️⃣ Encode dataset (CPU-friendly)
@@ -95,28 +93,35 @@ def train_flashpack_model(
     for i, item in enumerate(dataset):
         short_list.append(encode_fn(item["short_prompt"]))
         long_list.append(encode_fn(item["long_prompt"]))
-        if (i + 1) % 50 == 0 or (i + 1) == len(dataset):
-            print(f" → Encoded {i+1}/{len(dataset)} prompts")
+
+        # Exit early if we hit max_encode
+        if (i + 1) >= max_encode:
+            print(f"⚡ Reached max encode limit: {max_encode} prompts, stopping early.")
+            break
+
+        # Progress logging
+        if (i + 1) % 50 == 0:
+            print(f" → Encoded {i+1}/{limit} prompts")
         gc.collect()
 
     short_embeddings = torch.vstack(short_list)
     long_embeddings = torch.vstack(long_list)
+    print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
 
-    # 4️⃣ Initialize model
+    # 4️⃣ Initialize and train model (same as before)
     model = GemmaTrainer(
         input_dim=short_embeddings.shape[1],
-        hidden_dim=min(512, short_embeddings.shape[1]),  # smaller hidden dim for speed
+        hidden_dim=min(512, short_embeddings.shape[1]),
        output_dim=long_embeddings.shape[1],
     ).to(device)
 
-    # 5️⃣ Training loop
     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
-    max_epochs = 50  # fewer epochs
+    max_epochs = 50
     tolerance = 1e-4
-    batch_size = 32  # smaller batch size for CPU
+    batch_size = 32
 
-    print("🚀 Training FlashPack mapper model (fast, CPU)...")
+    print("🚀 Training FlashPack mapper model (CPU)...")
     n = short_embeddings.shape[0]
     for epoch in range(max_epochs):
         model.train()
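
This commit renames the old subset_limit keyword to max_encode, so existing callers have to switch to the new parameter name. Below is a minimal usage sketch, not part of the commit, assuming app_flash.py is on the import path and train_flashpack_model keeps the signature shown in the first hunk; the diff only annotates the return type as tuple, so the result is left unpacked here.

# Usage sketch (assumption: app_flash.py is importable and exposes
# train_flashpack_model with the signature shown in the hunk above).
from app_flash import train_flashpack_model

result = train_flashpack_model(
    dataset_name="gokaygokay/prompt-enhancer-dataset",
    model_name="gpt2",
    max_length=32,
    max_encode=500,     # encode at most 500 prompts instead of the default 2000
    push_to_hub=False,  # keep the trained mapper local, do not push to the Hub
)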