rahul7star committed on
Commit
6c0c98e
·
verified ·
1 Parent(s): b9eee9c

Update app_flash1.py

Browse files
Files changed (1) hide show
  1. app_flash1.py +72 -95
app_flash1.py CHANGED
@@ -9,7 +9,7 @@ from datasets import load_dataset
9
  from transformers import AutoTokenizer, AutoModel
10
  from flashpack import FlashPackMixin
11
  from huggingface_hub import Repository, list_repo_files, hf_hub_download
12
- from typing import Tuple
13
 
14
  # ============================================================
15
  # 🖥 Device Setup
@@ -64,43 +64,51 @@ def build_encoder(model_name="gpt2", max_length=128):
64
  return tokenizer, embed_model, encode
65
 
66
  # ============================================================
67
- # 3️⃣ Push to Hugging Face
68
  # ============================================================
69
- def push_flashpack_model_to_hf(model, hf_repo):
70
  with tempfile.TemporaryDirectory() as tmp_dir:
71
  repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
 
72
  model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"))
 
 
 
 
73
  with open(os.path.join(tmp_dir, "README.md"), "w") as f:
74
  f.write("# FlashPack Model\nTrained locally and pushed to HF.")
75
  repo.push_to_hub()
76
- print(f"✅ Model pushed to {hf_repo}")
77
 
78
  # ============================================================
79
- # 4️⃣ Training Logic
80
  # ============================================================
81
  def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
82
  hf_repo="rahul7star/FlashPack",
83
  max_encode=1000):
84
  print("📦 Loading dataset...")
85
- dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
86
- print(f"✅ Loaded {len(dataset)} samples")
 
87
 
88
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
89
 
90
  def encode_dataset(ds):
91
- s_list, l_list = [], []
92
  for i, item in enumerate(ds):
93
  s_list.append(encode_fn(item["short_prompt"]))
94
  l_list.append(encode_fn(item["long_prompt"]))
 
 
95
  if (i + 1) % 50 == 0:
96
  print(f" → Encoded {i + 1}/{len(ds)}")
97
  gc.collect()
98
- return torch.vstack(s_list), torch.vstack(l_list)
99
 
100
- short_emb, long_emb = encode_dataset(dataset)
101
- input_dim, output_dim = short_emb.shape[1], long_emb.shape[1]
102
- model = GemmaTrainer(input_dim, 1024, output_dim)
103
 
 
104
  optimizer = optim.Adam(model.parameters(), lr=1e-3)
105
  loss_fn = nn.CosineSimilarity(dim=1)
106
 
@@ -108,122 +116,91 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
108
  for epoch in range(20):
109
  model.train()
110
  optimizer.zero_grad()
111
- preds = model(short_emb)
112
- loss = 1 - loss_fn(preds, long_emb).mean()
113
  loss.backward()
114
  optimizer.step()
115
- print(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
 
 
 
 
 
 
 
 
116
  if loss.item() < 0.01:
117
  print("🎯 Early stopping.")
118
  break
119
 
120
- push_flashpack_model_to_hf(model, hf_repo)
121
- return model, tokenizer, embed_model, dataset, long_emb
122
 
123
  # ============================================================
124
- # 5️⃣ Load or Train
125
  # ============================================================
126
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
127
  print(f"🔍 Checking for model in repo: {hf_repo}")
 
 
128
 
129
- local_path = "model.flashpack"
130
-
131
- # 1️⃣ Try local first
132
- if os.path.exists(local_path):
133
- print("✅ Found local model.flashpack — loading it directly.")
134
- model = GemmaTrainer().from_flashpack(local_path)
135
- model.eval()
136
- tokenizer, embed_model, _ = build_encoder("gpt2")
137
  else:
138
- # 2️⃣ Check HF repo
139
- try:
140
- files = list_repo_files(hf_repo)
141
- if "model.flashpack" in files:
142
- print("✅ Found model.flashpack in repo — downloading and loading it.")
143
- local_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
144
- model = GemmaTrainer().from_flashpack(local_path)
145
- model.eval()
146
- tokenizer, embed_model, _ = build_encoder("gpt2")
147
- else:
148
- print("🚫 model.flashpack not found — starting training.")
149
- return train_flashpack_model(hf_repo=hf_repo)
150
- except Exception as e:
151
- print(f"⚠️ Error checking repo: {e}")
152
- print("⏬ Training new model instead.")
153
- return train_flashpack_model(hf_repo=hf_repo)
154
-
155
- # ✅ Enhance function without dataset
156
- def enhance_fn(prompt, chat):
157
- chat = chat or []
158
- short_emb = encode_prompt(prompt, tokenizer, embed_model)
159
- mapped = model(short_emb.to(device)).cpu()
160
- # We don't need a dataset; just return the mapped tensor info as string
161
- chat.append({"role": "user", "content": prompt})
162
- chat.append({"role": "assistant", "content": f"✅ Model loaded — ready to enhance.\nOutput vector: {mapped[0].tolist()[:8]} ..."})
163
- return chat
164
-
165
- return model, tokenizer, embed_model, None, None, enhance_fn
166
-
167
-
168
-
169
- # ============================================================
170
- # 6️⃣ Encode & Enhance Functions
171
- # ============================================================
172
- @torch.no_grad()
173
- def encode_prompt(prompt, tokenizer, embed_model):
174
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
175
- padding="max_length", max_length=128).to(device)
176
- hidden = embed_model(**inputs).last_hidden_state
177
- mean_pool = hidden.mean(dim=1)
178
- max_pool, _ = hidden.max(dim=1)
179
- return torch.cat([mean_pool, max_pool], dim=1).cpu()
180
-
181
 
182
- def make_enhance_fn(model, tokenizer, embed_model, long_emb, dataset):
183
  @torch.no_grad()
184
- def fn(prompt, chat):
185
  chat = chat or []
186
- short_emb = encode_prompt(prompt, tokenizer, embed_model)
187
- mapped = model(short_emb.to(device)).cpu()
188
- sims = (long_emb @ mapped.t()).squeeze(1)
189
- best = int(sims.argmax())
190
- enhanced = dataset[best]["long_prompt"]
 
191
  chat.append({"role": "user", "content": prompt})
192
- chat.append({"role": "assistant", "content": enhanced})
193
  return chat
194
- return fn
 
195
 
196
  # ============================================================
197
- # 7️⃣ Gradio UI
198
  # ============================================================
199
  with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
200
  gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")
201
 
202
  chatbot = gr.Chatbot(height=400, type="messages")
203
-
204
  user_input = gr.Textbox(label="Your prompt")
205
  send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
206
  clear_btn = gr.Button("🧹 Clear")
207
  train_btn = gr.Button("🧩 Train Model", variant="secondary")
208
-
209
  status = gr.Markdown("Status: Ready")
210
 
211
- # Load model initially
212
- model, tokenizer, embed_model, dataset, long_emb, enhance_fn = get_flashpack_model()
213
-
214
- def enhance(prompt, chat):
215
- return enhance_fn(prompt, chat)
216
-
217
- def retrain():
218
- global model, tokenizer, embed_model, dataset, long_emb, enhance_fn
219
- model, tokenizer, embed_model, dataset, long_emb = train_flashpack_model()
220
- enhance_fn = make_enhance_fn(model, tokenizer, embed_model, long_emb, dataset)
221
- return "✅ Model retrained and pushed to HF!"
222
 
223
- send_btn.click(enhance, [user_input, chatbot], chatbot)
224
- user_input.submit(enhance, [user_input, chatbot], chatbot)
225
  clear_btn.click(lambda: [], None, chatbot)
226
- train_btn.click(retrain, None, status)
227
 
228
  if __name__ == "__main__":
229
  demo.launch(show_error=True)
 
9
  from transformers import AutoTokenizer, AutoModel
10
  from flashpack import FlashPackMixin
11
  from huggingface_hub import Repository, list_repo_files, hf_hub_download
12
+ import pickle
13
 
14
  # ============================================================
15
  # 🖥 Device Setup
 
64
  return tokenizer, embed_model, encode
65
 
66
  # ============================================================
67
+ # 3️⃣ Push to Hugging Face (model + mapping)
68
  # ============================================================
69
+ def push_flashpack_model_to_hf(model, short_texts, long_texts, hf_repo):
70
  with tempfile.TemporaryDirectory() as tmp_dir:
71
  repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
72
+ # Save model
73
  model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"))
74
+ # Save text mapping
75
+ with open(os.path.join(tmp_dir, "text_mapping.pkl"), "wb") as f:
76
+ pickle.dump({"short": short_texts, "long": long_texts}, f)
77
+ # README
78
  with open(os.path.join(tmp_dir, "README.md"), "w") as f:
79
  f.write("# FlashPack Model\nTrained locally and pushed to HF.")
80
  repo.push_to_hub()
81
+ print(f"✅ Model and text mapping pushed to {hf_repo}")
82
 
83
  # ============================================================
84
+ # 4️⃣ Training Logic (train + test splits)
85
  # ============================================================
86
  def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
87
  hf_repo="rahul7star/FlashPack",
88
  max_encode=1000):
89
  print("📦 Loading dataset...")
90
+ dataset_train = load_dataset(dataset_name, split="train").select(range(max_encode))
91
+ dataset_test = load_dataset(dataset_name, split="test").select(range(max_encode // 10))
92
+ print(f"✅ Loaded {len(dataset_train)} train and {len(dataset_test)} test samples")
93
 
94
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
95
 
96
  def encode_dataset(ds):
97
+ s_list, l_list, short_texts, long_texts = [], [], [], []
98
  for i, item in enumerate(ds):
99
  s_list.append(encode_fn(item["short_prompt"]))
100
  l_list.append(encode_fn(item["long_prompt"]))
101
+ short_texts.append(item["short_prompt"])
102
+ long_texts.append(item["long_prompt"])
103
  if (i + 1) % 50 == 0:
104
  print(f" → Encoded {i + 1}/{len(ds)}")
105
  gc.collect()
106
+ return torch.vstack(s_list), torch.vstack(l_list), short_texts, long_texts
107
 
108
+ short_emb_train, long_emb_train, short_texts_train, long_texts_train = encode_dataset(dataset_train)
109
+ short_emb_test, long_emb_test, _, _ = encode_dataset(dataset_test)
 
110
 
111
+ model = GemmaTrainer()
112
  optimizer = optim.Adam(model.parameters(), lr=1e-3)
113
  loss_fn = nn.CosineSimilarity(dim=1)
114
 
 
116
  for epoch in range(20):
117
  model.train()
118
  optimizer.zero_grad()
119
+ preds = model(short_emb_train)
120
+ loss = 1 - loss_fn(preds, long_emb_train).mean()
121
  loss.backward()
122
  optimizer.step()
123
+ print(f"Epoch {epoch+1}/20 | Train Loss: {loss.item():.5f}")
124
+
125
+ # Evaluate on test
126
+ model.eval()
127
+ with torch.no_grad():
128
+ test_preds = model(short_emb_test)
129
+ test_loss = 1 - loss_fn(test_preds, long_emb_test).mean()
130
+ print(f" | Test Loss: {test_loss.item():.5f}")
131
+
132
  if loss.item() < 0.01:
133
  print("🎯 Early stopping.")
134
  break
135
 
136
+ push_flashpack_model_to_hf(model, short_texts_train, long_texts_train, hf_repo)
137
+ return model, tokenizer, embed_model, short_emb_train, long_emb_train, short_texts_train, long_texts_train
138
 
139
  # ============================================================
140
+ # 5️⃣ Load pretrained model for query
141
  # ============================================================
142
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
143
  print(f"🔍 Checking for model in repo: {hf_repo}")
144
+ local_model_path = "model.flashpack"
145
+ local_mapping_path = "text_mapping.pkl"
146
 
147
+ if os.path.exists(local_model_path) and os.path.exists(local_mapping_path):
148
+ print("✅ Loading local model and mapping")
 
 
 
 
 
 
149
  else:
150
+ files = list_repo_files(hf_repo)
151
+ if "model.flashpack" in files:
152
+ print("✅ Downloading model from HF")
153
+ local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
154
+ if "text_mapping.pkl" in files:
155
+ print("✅ Downloading text mapping from HF")
156
+ local_mapping_path = hf_hub_download(repo_id=hf_repo, filename="text_mapping.pkl")
157
+
158
+ # Load model
159
+ model = GemmaTrainer().from_flashpack(local_model_path)
160
+ model.eval()
161
+ tokenizer, embed_model, encode_fn = build_encoder("gpt2")
162
+ # Load mapping
163
+ with open(local_mapping_path, "rb") as f:
164
+ mapping = pickle.load(f)
165
+ short_texts, long_texts = mapping["short"], mapping["long"]
166
+ short_embs = torch.vstack([encode_fn(s) for s in short_texts])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ # Enhance function
169
  @torch.no_grad()
170
+ def enhance_fn(prompt, chat):
171
  chat = chat or []
172
+ query_emb = encode_fn(prompt)
173
+ mapped = model(query_emb.to(device)).cpu()
174
+ # Compute cosine similarity between the mapped query and the stored short-prompt embeddings (short_embs)
175
+ sims = torch.nn.functional.cosine_similarity(mapped, short_embs)
176
+ best_idx = int(sims.argmax())
177
+ best_long_prompt = long_texts[best_idx]
178
  chat.append({"role": "user", "content": prompt})
179
+ chat.append({"role": "assistant", "content": best_long_prompt})
180
  return chat
181
+
182
+ return model, tokenizer, embed_model, enhance_fn
183
 
184
  # ============================================================
185
+ # 6️⃣ Gradio UI
186
  # ============================================================
187
  with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
188
  gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")
189
 
190
  chatbot = gr.Chatbot(height=400, type="messages")
 
191
  user_input = gr.Textbox(label="Your prompt")
192
  send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
193
  clear_btn = gr.Button("🧹 Clear")
194
  train_btn = gr.Button("🧩 Train Model", variant="secondary")
 
195
  status = gr.Markdown("Status: Ready")
196
 
197
+ # Load pretrained model
198
+ model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
 
 
 
 
 
 
 
 
 
199
 
200
+ send_btn.click(enhance_fn, [user_input, chatbot], chatbot)
201
+ user_input.submit(enhance_fn, [user_input, chatbot], chatbot)
202
  clear_btn.click(lambda: [], None, chatbot)
203
+ train_btn.click(lambda: train_flashpack_model(), None, status)
204
 
205
  if __name__ == "__main__":
206
  demo.launch(show_error=True)