rahul7star committed
Commit 3589a82 · verified · 1 Parent(s): 390f3c3

Update app_flash1.py

Files changed (1)
  1. app_flash1.py +96 -128
app_flash1.py CHANGED
@@ -10,17 +10,19 @@ from transformers import AutoTokenizer, AutoModel
 from flashpack import FlashPackMixin
 from huggingface_hub import Repository
 from typing import Tuple
-from sklearn.model_selection import train_test_split
 
 device = torch.device("cpu")
 torch.set_num_threads(4)
-print(f"🔧 Using device: {device} (CPU-only)")
 
 # ============================================================
-# 1️⃣ Model
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
         super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
@@ -36,187 +38,160 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         return x
 
 # ============================================================
-# 2️⃣ Encoder with batch mean+max pooling
 # ============================================================
-def build_encoder(model_name="gpt2", max_length=128):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    embed_model = AutoModel.from_pretrained(model_name).to(device)
-    embed_model.eval()
 
-    @torch.no_grad()
-    def encode_batch(prompts: list, batch_size=16) -> torch.Tensor:
-        embeddings = []
-        for i in range(0, len(prompts), batch_size):
-            batch = prompts[i:i+batch_size]
-            inputs = tokenizer(batch, return_tensors="pt", truncation=True,
-                               padding="max_length", max_length=max_length).to(device)
-            last_hidden = embed_model(**inputs).last_hidden_state
-            mean_pool = last_hidden.mean(dim=1)
-            max_pool, _ = last_hidden.max(dim=1)
-            batch_emb = torch.cat([mean_pool, max_pool], dim=1)
-            embeddings.append(batch_emb.cpu())
-        return torch.vstack(embeddings)
 
-    return tokenizer, embed_model, encode_batch
 
 # ============================================================
-# 3️⃣ Push model to HF
 # ============================================================
 def push_flashpack_model_to_hf(model, hf_repo: str):
     logs = []
     with tempfile.TemporaryDirectory() as tmp_dir:
-        logs.append(f"📂 Using temporary directory: {tmp_dir}")
         repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
         pack_path = os.path.join(tmp_dir, "model.flashpack")
         model.save_flashpack(pack_path, target_dtype=torch.float32)
         readme_path = os.path.join(tmp_dir, "README.md")
         with open(readme_path, "w") as f:
-            f.write("# FlashPack Model\nThis repo contains a FlashPack model trained for short→long prompt mapping.")
         repo.push_to_hub()
-        logs.append(f"✅ Model pushed to Hugging Face repo: {hf_repo}")
     return logs
 
 # ============================================================
-# 4️⃣ Train with train/test split & detailed logging
 # ============================================================
 def train_flashpack_model(
-    dataset_name="rahul7star/prompt-enhancer-dataset",
-    max_encode=1000,
-    hidden_dim=1024,
-    hf_repo="rahul7star/FlashPack",
-    push_to_hub=True,
-    test_split=0.1,
-    batch_size=32,
-    max_epochs=50,
-    target_test_loss=0.01
 ) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
-
     print("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train")
-    limit = min(max_encode, len(dataset))
-    dataset = dataset.select(range(limit))
-    print(f"⚡ Using {len(dataset)} prompts for training")
-
-    short_prompts = [item["short_prompt"] for item in dataset]
-    long_prompts = [item["long_prompt"] for item in dataset]
-
-    # Split
-    train_short, test_short, train_long, test_long = train_test_split(
-        short_prompts, long_prompts, test_size=test_split, random_state=42
-    )
-    print(f"🔹 Train size: {len(train_short)}, Test size: {len(test_short)}")
-
-    tokenizer, embed_model, encode_batch = build_encoder("gpt2", max_length=128)
-
-    # Encode
-    print("⚡ Encoding training prompts...")
-    train_short_emb = encode_batch(train_short)
-    train_long_emb = encode_batch(train_long)
-    print(f"✅ Train embeddings shape: {train_short_emb.shape}, {train_long_emb.shape}")
-
-    print("⚡ Encoding test prompts...")
-    test_short_emb = encode_batch(test_short)
-    test_long_emb = encode_batch(test_long)
-    print(f"✅ Test embeddings shape: {test_short_emb.shape}, {test_long_emb.shape}")
-
-    input_dim = train_short_emb.shape[1]
-    output_dim = train_long_emb.shape[1]
-
     model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
 
     criterion = nn.CosineSimilarity(dim=1)
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
-
-    n_train = train_short_emb.shape[0]
 
     print("🚀 Training model...")
     for epoch in range(max_epochs):
         model.train()
         epoch_loss = 0.0
-        perm = torch.randperm(n_train)
-        for start in range(0, n_train, batch_size):
             idx = perm[start:start+batch_size]
-            inputs = train_short_emb[idx].to(device)
-            targets = train_long_emb[idx].to(device)
-
             optimizer.zero_grad()
             outputs = model(inputs)
             loss = 1 - criterion(outputs, targets).mean()
             loss.backward()
             optimizer.step()
             epoch_loss += loss.item() * inputs.size(0)
 
-        epoch_loss /= n_train
-
-        # Evaluate on test
         model.eval()
         with torch.no_grad():
-            test_outputs = model(test_short_emb.to(device))
-            test_loss = (1 - criterion(test_outputs, test_long_emb.to(device)).mean()).item()
 
-        print(f"Epoch {epoch+1}/{max_epochs} Train loss: {epoch_loss:.6f}, Test loss: {test_loss:.6f}")
 
-        # Check if model is perfect enough
-        if test_loss <= target_test_loss:
-            print(f" Target test loss reached ({test_loss:.6f}) stopping training early.")
             break
 
-    # Push to HF if trained well
-    logs = []
-    if push_to_hub and test_loss <= target_test_loss:
         logs = push_flashpack_model_to_hf(model, hf_repo)
         for log in logs:
             print(log)
-    elif push_to_hub:
-        print(f"⚠️ Test loss too high ({test_loss:.6f}); skipping HF upload.")
 
-    return model, dataset, embed_model, tokenizer, train_long_emb
 
 # ============================================================
-# 5️⃣ Load or train
 # ============================================================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     try:
         print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
         model = GemmaTrainer.from_flashpack(hf_repo)
         model.eval()
-        tokenizer, embed_model, encode_batch = build_encoder("gpt2", max_length=128)
-        return model, tokenizer, embed_model
     except Exception as e:
         print(f"⚠️ Load failed: {e}")
-        print("⏬ Training a new FlashPack model locally...")
-        return train_flashpack_model(hf_repo=hf_repo)
 
 # ============================================================
-# 6️⃣ Load or train
-# ============================================================
-model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
-
-# ============================================================
-# 7️⃣ Inference helpers
 # ============================================================
 @torch.no_grad()
-def encode_for_inference(prompt: str) -> torch.Tensor:
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
-                       padding="max_length", max_length=128).to(device)
-    last_hidden = embed_model(**inputs).last_hidden_state
-    mean_pool = last_hidden.mean(dim=1)
-    max_pool, _ = last_hidden.max(dim=1)
-    return torch.cat([mean_pool, max_pool], dim=1).cpu()
-
-def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
     chat_history = chat_history or []
-    short_emb = encode_for_inference(user_prompt)
     mapped = model(short_emb.to(device)).cpu()
-
     sims = (long_embeddings @ mapped.t()).squeeze(1)
     long_norms = long_embeddings.norm(dim=1)
     mapped_norm = mapped.norm()
     sims = sims / (long_norms * (mapped_norm + 1e-12))
-
     best_idx = int(sims.argmax().item())
     enhanced_prompt = dataset[best_idx]["long_prompt"]
 
@@ -225,28 +200,21 @@ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_h
     return chat_history
 
 # ============================================================
-# 8️⃣ Gradio UI
 # ============================================================
-with gr.Blocks(title="Prompt Enhancer FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
-    gr.Markdown(
-        """
-        # ✨ Prompt Enhancer (FlashPack mapper)
-        Enter a short prompt, and the model will **expand it with details and creative context**.
-        (CPU-only mode.)
-        """
-    )
 
     with gr.Row():
         chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
         with gr.Column(scale=1):
             user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
-            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
-            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
             send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
             clear_btn = gr.Button("🧹 Clear Chat")
 
-    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
-    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
     clear_btn.click(lambda: [], None, chatbot)
 
 if __name__ == "__main__":
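
The updated file (the new side of the diff) follows. The headline change: the sklearn dependency is dropped, and the shuffled `train_test_split(..., random_state=42)` gives way to an index-based 80/20 split built from `Dataset.select`. Worth noting as a reviewer: `select(range(n_train))` takes the first 80% in dataset order and does not shuffle. A minimal sketch of the new split, with an optional shuffle that would mirror the old behavior (the `shuffle` line is a suggestion, not part of the commit):

    from datasets import load_dataset

    dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train")
    # dataset = dataset.shuffle(seed=42)  # suggestion only: mirrors the old random_state=42
    n_train = int(0.8 * len(dataset))
    train_dataset = dataset.select(range(n_train))                # first 80%
    test_dataset = dataset.select(range(n_train, len(dataset)))   # remaining 20%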
 
@@ -10,17 +10,19 @@ from transformers import AutoTokenizer, AutoModel
 from flashpack import FlashPackMixin
 from huggingface_hub import Repository
 from typing import Tuple
 
+# ============================================================
+# 🖥 CPU device setup
+# ============================================================
 device = torch.device("cpu")
 torch.set_num_threads(4)
+print(f"🔧 Using device: {device} (CPU-only mode)")
 
 # ============================================================
+# 1️⃣ FlashPack MLP model (CPU-friendly)
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
+    def __init__(self, input_dim: int, hidden_dim: int = 512, output_dim: int = 768):
         super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
@@ -36,187 +38,160 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
         return x
 
 # ============================================================
+# 2️⃣ Lazy-loading GPT-2 encoder
 # ============================================================
+_embed_model = None
+_tokenizer = None
 
+def get_encoder(model_name="gpt2", max_length=64):
+    global _embed_model, _tokenizer
+    if _embed_model is None or _tokenizer is None:
+        print("⚡ Loading GPT-2 encoder model...")
+        _tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if _tokenizer.pad_token is None:
+            _tokenizer.pad_token = _tokenizer.eos_token
+        _embed_model = AutoModel.from_pretrained(model_name).to(device)
+        _embed_model.eval()
+    return _tokenizer, _embed_model
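
The module-level `_embed_model`/`_tokenizer` pair implements a lazy singleton: GPT-2 is loaded on first use rather than at import time, which keeps startup fast when the FlashPack model can simply be loaded from the Hub. An equivalent pattern, shown only as a sketch and not what the commit uses, is `functools.lru_cache`:

    from functools import lru_cache
    from transformers import AutoModel, AutoTokenizer

    @lru_cache(maxsize=1)  # first call loads; later calls return the cached pair
    def get_encoder_cached(model_name: str = "gpt2"):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
        model = AutoModel.from_pretrained(model_name).eval()
        return tokenizer, model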
 
 
 
+@torch.no_grad()
+def encode_prompt(prompt: str) -> torch.Tensor:
+    tokenizer, embed_model = get_encoder()
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
+                       padding="max_length", max_length=64).to(device)
+    last_hidden = embed_model(**inputs).last_hidden_state
+    mean_pool = last_hidden.mean(dim=1)
+    max_pool, _ = last_hidden.max(dim=1)
+    return torch.cat([mean_pool, max_pool], dim=1).cpu()
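
Note on dimensions: GPT-2's hidden size is 768, and `encode_prompt` concatenates mean- and max-pooled hidden states, so every embedding is 1536-dimensional. The `GemmaTrainer` defaults (`hidden_dim=512`, `output_dim=768`) are therefore overridden at training time, where both `input_dim` and `output_dim` are measured from the data. A dependency-light check of the pooling arithmetic:

    import torch

    # Stand-in for GPT-2's last_hidden_state: (batch=2, seq_len=64, hidden=768)
    last_hidden = torch.randn(2, 64, 768)
    mean_pool = last_hidden.mean(dim=1)   # (2, 768)
    max_pool, _ = last_hidden.max(dim=1)  # (2, 768)
    emb = torch.cat([mean_pool, max_pool], dim=1)
    print(emb.shape)  # torch.Size([2, 1536])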
 
 # ============================================================
+# 3️⃣ Push FlashPack model to Hugging Face Hub
 # ============================================================
 def push_flashpack_model_to_hf(model, hf_repo: str):
     logs = []
     with tempfile.TemporaryDirectory() as tmp_dir:
+        logs.append(f"📂 Using temp dir: {tmp_dir}")
         repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
         pack_path = os.path.join(tmp_dir, "model.flashpack")
         model.save_flashpack(pack_path, target_dtype=torch.float32)
         readme_path = os.path.join(tmp_dir, "README.md")
         with open(readme_path, "w") as f:
+            f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
         repo.push_to_hub()
+        logs.append(f"✅ Model pushed to HF: {hf_repo}")
     return logs
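
`Repository` has been deprecated in recent `huggingface_hub` releases in favor of the HTTP-based `HfApi` methods, so this helper may warn or break on newer installs. A sketch of the same upload without a git clone (`upload_file` is the real `HfApi` method; the wrapper name here is hypothetical):

    from huggingface_hub import HfApi

    def push_flashpack_via_api(pack_path: str, readme_path: str, hf_repo: str):
        api = HfApi()  # uses the cached login token
        api.upload_file(path_or_fileobj=pack_path, path_in_repo="model.flashpack", repo_id=hf_repo)
        api.upload_file(path_or_fileobj=readme_path, path_in_repo="README.md", repo_id=hf_repo)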
 
 # ============================================================
+# 4️⃣ Train FlashPack model with train/test split
 # ============================================================
 def train_flashpack_model(
+    dataset_name: str = "rahul7star/prompt-enhancer-dataset",
+    max_encode: int = 500,
+    hidden_dim: int = 512,
+    push_to_hub: bool = True,
+    hf_repo: str = "rahul7star/FlashPack",
+    early_stop_threshold: float = 0.001
 ) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
+
     print("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train")
+    dataset = dataset.select(range(min(max_encode, len(dataset))))
+    n_train = int(0.8 * len(dataset))
+    n_test = len(dataset) - n_train
+    train_dataset = dataset.select(range(n_train))
+    test_dataset = dataset.select(range(n_train, len(dataset)))
+    print(f" Train: {n_train}, Test: {n_test}")
+
+    # Encode prompts lazily
+    def batch_encode(ds):
+        short_list, long_list = [], []
+        for i, item in enumerate(ds):
+            short_list.append(encode_prompt(item["short_prompt"]))
+            long_list.append(encode_prompt(item["long_prompt"]))
+            if (i+1) % 20 == 0 or (i+1) == len(ds):
+                print(f" → Encoded {i+1}/{len(ds)} prompts")
+                gc.collect()
+        return torch.vstack(short_list), torch.vstack(long_list)
+
+    short_train, long_train = batch_encode(train_dataset)
+    short_test, long_test = batch_encode(test_dataset)
+    print(f"✅ Embeddings shapes: short_train={short_train.shape}, long_train={long_train.shape}")
+
+    input_dim = short_train.shape[1]
+    output_dim = long_train.shape[1]
 
     model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)
 
     criterion = nn.CosineSimilarity(dim=1)
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
+    max_epochs = 50
+    batch_size = 16
 
     print("🚀 Training model...")
+    n = short_train.shape[0]
     for epoch in range(max_epochs):
         model.train()
+        perm = torch.randperm(n)
         epoch_loss = 0.0
+        for start in range(0, n, batch_size):
             idx = perm[start:start+batch_size]
+            inputs = short_train[idx].to(device)
+            targets = long_train[idx].to(device)
             optimizer.zero_grad()
             outputs = model(inputs)
             loss = 1 - criterion(outputs, targets).mean()
             loss.backward()
             optimizer.step()
             epoch_loss += loss.item() * inputs.size(0)
+        epoch_loss /= n
 
+        # Evaluate on test set
         model.eval()
         with torch.no_grad():
+            outputs_test = model(short_test.to(device))
+            test_loss = 1 - criterion(outputs_test, long_test.to(device)).mean().item()
 
+        print(f"Epoch {epoch+1}/{max_epochs} | Train Loss={epoch_loss:.6f} | Test Loss={test_loss:.6f}")
 
+        # Early stop: very low test loss means model is good
+        if test_loss < early_stop_threshold:
+            print("🎯 Early stop: test loss below threshold. Model is ready!")
             break
 
+    if push_to_hub:
         logs = push_flashpack_model_to_hf(model, hf_repo)
         for log in logs:
             print(log)
 
+    return model, dataset, None, None, long_train  # embed_model and tokenizer lazy-loaded
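
The objective is `1 - cosine_similarity`, so the per-batch loss lives in [0, 2]: 0 when predicted and target embeddings point the same way, 1 when orthogonal, 2 when opposite. A tiny self-contained check of the loss used above:

    import torch
    import torch.nn as nn

    criterion = nn.CosineSimilarity(dim=1)
    a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

    print(1 - criterion(a, a).mean())          # tensor(0.)  identical directions
    print(1 - criterion(a, a.flip(1)).mean())  # tensor(1.)  orthogonal rows
    print(1 - criterion(a, -a).mean())         # tensor(2.)  opposite directions

Separately, `batch_encode` above calls `gc.collect()`, which needs `import gc` among the imports above line 10 of the file; that region is outside this diff's context, so it is worth confirming.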
 
 # ============================================================
+# 5️⃣ Lazy load or train
 # ============================================================
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     try:
         print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
         model = GemmaTrainer.from_flashpack(hf_repo)
         model.eval()
+        print("✅ Loaded model from HF")
+        return model
     except Exception as e:
         print(f"⚠️ Load failed: {e}")
+        print("⏬ Training new FlashPack model locally...")
+        model, dataset, _, _, long_embeddings = train_flashpack_model()
+        return model, dataset, long_embeddings
 
 # ============================================================
+# 6️⃣ Inference helpers
 # ============================================================
 @torch.no_grad()
+def enhance_prompt(user_prompt: str, chat_history, model, long_embeddings, dataset):
     chat_history = chat_history or []
+    short_emb = encode_prompt(user_prompt)
     mapped = model(short_emb.to(device)).cpu()
 
     sims = (long_embeddings @ mapped.t()).squeeze(1)
     long_norms = long_embeddings.norm(dim=1)
     mapped_norm = mapped.norm()
     sims = sims / (long_norms * (mapped_norm + 1e-12))
 
     best_idx = int(sims.argmax().item())
     enhanced_prompt = dataset[best_idx]["long_prompt"]
 
@@ -225,28 +200,21 @@ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_h
     return chat_history
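
`enhance_prompt` does retrieval, not generation: the mapped vector is scored by cosine similarity against every precomputed long-prompt embedding, and the best matching training example is returned verbatim. That is also why the old `temperature`/`max_tokens` sliders, unused by the code shown here, disappear from the UI below. A standalone sketch of the same nearest-neighbor lookup (shapes assumed from the GPT-2 pooling above):

    import torch

    long_embeddings = torch.randn(500, 1536)  # precomputed corpus embeddings
    mapped = torch.randn(1, 1536)             # mapper output for one short prompt

    sims = (long_embeddings @ mapped.t()).squeeze(1)
    sims = sims / (long_embeddings.norm(dim=1) * (mapped.norm() + 1e-12))
    best_idx = int(sims.argmax().item())      # row of the long prompt to return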
 
 # ============================================================
+# 7️⃣ Launch Gradio app
 # ============================================================
+model, dataset, long_embeddings = get_flashpack_model()
 
+with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# ✨ Prompt Enhancer (FlashPack mapper)\nEnter a short prompt, and it will expand it.")
     with gr.Row():
         chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
         with gr.Column(scale=1):
             user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
             send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
             clear_btn = gr.Button("🧹 Clear Chat")
 
+    send_btn.click(enhance_prompt, [user_prompt, chatbot, model, long_embeddings, dataset], chatbot)
+    user_prompt.submit(enhance_prompt, [user_prompt, chatbot, model, long_embeddings, dataset], chatbot)
     clear_btn.click(lambda: [], None, chatbot)
 
 if __name__ == "__main__":
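
Two loose ends in the new wiring. First, `get_flashpack_model` returns a bare `model` on the successful load path but a 3-tuple on the training path, so the module-level unpacking `model, dataset, long_embeddings = get_flashpack_model()` only works when loading fails. Second, `gr.Button.click` expects Gradio components as inputs, so passing `model`, `long_embeddings`, and `dataset` in the input list will not behave as intended. A sketch of one way to fix both, closing over the loaded objects instead (reviewer suggestion, not part of the commit; assumes the file's `encode_prompt` and `device`):

    def make_enhance_fn(model, long_embeddings, dataset):
        @torch.no_grad()
        def enhance(user_prompt, chat_history):
            chat_history = chat_history or []
            mapped = model(encode_prompt(user_prompt).to(device)).cpu()
            sims = (long_embeddings @ mapped.t()).squeeze(1)
            sims = sims / (long_embeddings.norm(dim=1) * (mapped.norm() + 1e-12))
            best = dataset[int(sims.argmax().item())]["long_prompt"]
            # message dicts match the chatbot's type="messages"
            chat_history.append({"role": "user", "content": user_prompt})
            chat_history.append({"role": "assistant", "content": best})
            return chat_history
        return enhance

    # send_btn.click(make_enhance_fn(model, long_embeddings, dataset),
    #                [user_prompt, chatbot], chatbot)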