rahul7star commited on
Commit
5ee9a29
·
verified ·
1 Parent(s): 8143e5c

Update app_flash1.py

Browse files
Files changed (1) hide show
  1. app_flash1.py +54 -64
app_flash1.py CHANGED
@@ -9,25 +9,20 @@ from datasets import load_dataset
9
  from transformers import AutoTokenizer, AutoModel
10
  from flashpack import FlashPackMixin
11
  from huggingface_hub import Repository, list_repo_files, hf_hub_download
12
- import pickle
13
 
14
- # ============================================================
15
- # 🖥 Device Setup
16
- # ============================================================
17
  device = torch.device("cpu")
18
  torch.set_num_threads(4)
19
  print(f"🔧 Using device: {device} (CPU-only mode)")
20
 
21
- # ============================================================
22
- # 1️⃣ Fixed Model Definition (FlashPack-compatible)
23
- # ============================================================
24
  class GemmaTrainer(nn.Module, FlashPackMixin):
25
  def __init__(self):
26
  super().__init__()
27
  input_dim = 1536
28
  hidden_dim = 1024
29
  output_dim = 1536
30
-
31
  self.fc1 = nn.Linear(input_dim, hidden_dim)
32
  self.relu = nn.ReLU()
33
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
@@ -41,11 +36,10 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
41
  x = self.fc3(x)
42
  return x
43
 
44
- # ============================================================
45
- # 2️⃣ Encoder Setup
46
- # ============================================================
47
  def build_encoder(model_name="gpt2", max_length=128):
48
- print(f"📦 Loading encoder: {model_name}")
49
  tokenizer = AutoTokenizer.from_pretrained(model_name)
50
  if tokenizer.pad_token is None:
51
  tokenizer.pad_token = tokenizer.eos_token
@@ -60,54 +54,42 @@ def build_encoder(model_name="gpt2", max_length=128):
60
  mean_pool = hidden.mean(dim=1)
61
  max_pool, _ = hidden.max(dim=1)
62
  return torch.cat([mean_pool, max_pool], dim=1).cpu()
63
-
64
  return tokenizer, embed_model, encode
65
 
66
- # ============================================================
67
- # 3️⃣ Push to Hugging Face (model + mapping)
68
- # ============================================================
69
- def push_flashpack_model_to_hf(model, short_texts, long_texts, hf_repo):
70
  with tempfile.TemporaryDirectory() as tmp_dir:
71
  repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
72
- # Save model
73
  model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"))
74
- # Save text mapping
75
- with open(os.path.join(tmp_dir, "text_mapping.pkl"), "wb") as f:
76
- pickle.dump({"short": short_texts, "long": long_texts}, f)
77
- # README
78
  with open(os.path.join(tmp_dir, "README.md"), "w") as f:
79
  f.write("# FlashPack Model\nTrained locally and pushed to HF.")
80
  repo.push_to_hub()
81
- print(f"✅ Model and text mapping pushed to {hf_repo}")
82
 
83
- # ============================================================
84
- # 4️⃣ Training Logic (train + test splits)
85
- # ============================================================
86
  def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
87
  hf_repo="rahul7star/FlashPack",
88
  max_encode=1000):
89
  print("📦 Loading dataset...")
90
- dataset_train = load_dataset(dataset_name, split="train").select(range(max_encode))
91
- dataset_test = load_dataset(dataset_name, split="test").select(range(max_encode // 10))
92
- print(f"✅ Loaded {len(dataset_train)} train and {len(dataset_test)} test samples")
93
-
94
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
95
 
96
  def encode_dataset(ds):
97
- s_list, l_list, short_texts, long_texts = [], [], [], []
98
  for i, item in enumerate(ds):
99
  s_list.append(encode_fn(item["short_prompt"]))
100
  l_list.append(encode_fn(item["long_prompt"]))
101
- short_texts.append(item["short_prompt"])
102
- long_texts.append(item["long_prompt"])
103
  if (i + 1) % 50 == 0:
104
  print(f" → Encoded {i + 1}/{len(ds)}")
105
  gc.collect()
106
- return torch.vstack(s_list), torch.vstack(l_list), short_texts, long_texts
107
-
108
- short_emb_train, long_emb_train, short_texts_train, long_texts_train = encode_dataset(dataset_train)
109
- short_emb_test, long_emb_test, _, _ = encode_dataset(dataset_test)
110
 
 
111
  model = GemmaTrainer()
112
  optimizer = optim.Adam(model.parameters(), lr=1e-3)
113
  loss_fn = nn.CosineSimilarity(dim=1)
@@ -116,38 +98,40 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
116
  for epoch in range(20):
117
  model.train()
118
  optimizer.zero_grad()
119
- preds = model(short_emb_train)
120
- loss = 1 - loss_fn(preds, long_emb_train).mean()
121
  loss.backward()
122
  optimizer.step()
123
- print(f"Epoch {epoch+1}/20 | Train Loss: {loss.item():.5f}")
124
-
125
- # Evaluate on test
126
- model.eval()
127
- with torch.no_grad():
128
- test_preds = model(short_emb_test)
129
- test_loss = 1 - loss_fn(test_preds, long_emb_test).mean()
130
- print(f" | Test Loss: {test_loss.item():.5f}")
131
-
132
  if loss.item() < 0.01:
133
  print("🎯 Early stopping.")
134
  break
135
 
136
- push_flashpack_model_to_hf(model, short_texts_train, long_texts_train, hf_repo)
137
- return model, tokenizer, embed_model, short_emb_train, long_emb_train, short_texts_train, long_texts_train
138
 
139
- # ============================================================
140
- # 5️⃣ Load pretrained model for query
141
- # ============================================================
142
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
143
- print(f"🔍 Checking for model in repo: {hf_repo}")
144
  local_model_path = "model.flashpack"
145
 
 
146
  if os.path.exists(local_model_path):
147
  print("✅ Loading local model")
148
  else:
149
- print("✅ Downloading model from HF")
150
- local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
 
 
 
 
 
 
 
 
 
 
151
 
152
  model = GemmaTrainer().from_flashpack(local_model_path)
153
  model.eval()
@@ -158,18 +142,17 @@ def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
158
  chat = chat or []
159
  short_emb = encode_fn(prompt)
160
  mapped = model(short_emb.to(device)).cpu()
161
- # convert mapped tensor into a string (this can be learned in training)
162
- # For demonstration, we just return a placeholder
163
- long_prompt = f"Enhanced long prompt for: {prompt}" # replace with your model's actual decoding if available
164
  chat.append({"role": "user", "content": prompt})
165
  chat.append({"role": "assistant", "content": long_prompt})
166
  return chat
167
 
168
  return model, tokenizer, embed_model, enhance_fn
169
 
170
- # ============================================================
171
- # 6️⃣ Gradio UI
172
- # ============================================================
173
  with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
174
  gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")
175
 
@@ -180,13 +163,20 @@ with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
180
  train_btn = gr.Button("🧩 Train Model", variant="secondary")
181
  status = gr.Markdown("Status: Ready")
182
 
183
- # Load pretrained model
184
  model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
185
 
186
  send_btn.click(enhance_fn, [user_input, chatbot], chatbot)
187
  user_input.submit(enhance_fn, [user_input, chatbot], chatbot)
188
  clear_btn.click(lambda: [], None, chatbot)
189
- train_btn.click(lambda: train_flashpack_model(), None, status)
 
 
 
 
 
 
 
190
 
191
  if __name__ == "__main__":
192
  demo.launch(show_error=True)
 
9
  from transformers import AutoTokenizer, AutoModel
10
  from flashpack import FlashPackMixin
11
  from huggingface_hub import Repository, list_repo_files, hf_hub_download
 
12
 
 
 
 
13
  device = torch.device("cpu")
14
  torch.set_num_threads(4)
15
  print(f"🔧 Using device: {device} (CPU-only mode)")
16
 
17
+ # ===========================
18
+ # Model Definition
19
+ # ===========================
20
  class GemmaTrainer(nn.Module, FlashPackMixin):
21
  def __init__(self):
22
  super().__init__()
23
  input_dim = 1536
24
  hidden_dim = 1024
25
  output_dim = 1536
 
26
  self.fc1 = nn.Linear(input_dim, hidden_dim)
27
  self.relu = nn.ReLU()
28
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
 
36
  x = self.fc3(x)
37
  return x
38
 
39
+ # ===========================
40
+ # Encoder
41
+ # ===========================
42
  def build_encoder(model_name="gpt2", max_length=128):
 
43
  tokenizer = AutoTokenizer.from_pretrained(model_name)
44
  if tokenizer.pad_token is None:
45
  tokenizer.pad_token = tokenizer.eos_token
 
54
  mean_pool = hidden.mean(dim=1)
55
  max_pool, _ = hidden.max(dim=1)
56
  return torch.cat([mean_pool, max_pool], dim=1).cpu()
57
+
58
  return tokenizer, embed_model, encode
59
 
60
+ # ===========================
61
+ # Push model to HF
62
+ # ===========================
63
+ def push_flashpack_model_to_hf(model, hf_repo):
64
  with tempfile.TemporaryDirectory() as tmp_dir:
65
  repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
 
66
  model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"))
 
 
 
 
67
  with open(os.path.join(tmp_dir, "README.md"), "w") as f:
68
  f.write("# FlashPack Model\nTrained locally and pushed to HF.")
69
  repo.push_to_hub()
70
+ print(f"✅ Model pushed to {hf_repo}")
71
 
72
+ # ===========================
73
+ # Training
74
+ # ===========================
75
  def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
76
  hf_repo="rahul7star/FlashPack",
77
  max_encode=1000):
78
  print("📦 Loading dataset...")
79
+ dataset = load_dataset(dataset_name, split="train").select(range(max_encode))
 
 
 
80
  tokenizer, embed_model, encode_fn = build_encoder("gpt2")
81
 
82
  def encode_dataset(ds):
83
+ s_list, l_list = [], []
84
  for i, item in enumerate(ds):
85
  s_list.append(encode_fn(item["short_prompt"]))
86
  l_list.append(encode_fn(item["long_prompt"]))
 
 
87
  if (i + 1) % 50 == 0:
88
  print(f" → Encoded {i + 1}/{len(ds)}")
89
  gc.collect()
90
+ return torch.vstack(s_list), torch.vstack(l_list)
 
 
 
91
 
92
+ short_emb, long_emb = encode_dataset(dataset)
93
  model = GemmaTrainer()
94
  optimizer = optim.Adam(model.parameters(), lr=1e-3)
95
  loss_fn = nn.CosineSimilarity(dim=1)
 
98
  for epoch in range(20):
99
  model.train()
100
  optimizer.zero_grad()
101
+ preds = model(short_emb)
102
+ loss = 1 - loss_fn(preds, long_emb).mean()
103
  loss.backward()
104
  optimizer.step()
105
+ print(f"Epoch {epoch+1}/20 | Loss: {loss.item():.5f}")
 
 
 
 
 
 
 
 
106
  if loss.item() < 0.01:
107
  print("🎯 Early stopping.")
108
  break
109
 
110
+ push_flashpack_model_to_hf(model, hf_repo)
111
+ return model, tokenizer, embed_model
112
 
113
+ # ===========================
114
+ # Load or Train
115
+ # ===========================
116
  def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
 
117
  local_model_path = "model.flashpack"
118
 
119
+ # 1. Try local
120
  if os.path.exists(local_model_path):
121
  print("✅ Loading local model")
122
  else:
123
+ # 2. Try HF
124
+ try:
125
+ files = list_repo_files(hf_repo)
126
+ if "model.flashpack" in files:
127
+ print("✅ Downloading model from HF")
128
+ local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
129
+ else:
130
+ print("🚫 Model not found on HF — will train a new model")
131
+ return train_flashpack_model(hf_repo=hf_repo)
132
+ except Exception as e:
133
+ print(f"⚠️ Error accessing HF: {e}. Training new model instead.")
134
+ return train_flashpack_model(hf_repo=hf_repo)
135
 
136
  model = GemmaTrainer().from_flashpack(local_model_path)
137
  model.eval()
 
142
  chat = chat or []
143
  short_emb = encode_fn(prompt)
144
  mapped = model(short_emb.to(device)).cpu()
145
+ # Simply return a placeholder text for demonstration
146
+ long_prompt = f"✅ Enhanced long prompt for: {prompt}"
 
147
  chat.append({"role": "user", "content": prompt})
148
  chat.append({"role": "assistant", "content": long_prompt})
149
  return chat
150
 
151
  return model, tokenizer, embed_model, enhance_fn
152
 
153
+ # ===========================
154
+ # Gradio UI
155
+ # ===========================
156
  with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
157
  gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")
158
 
 
163
  train_btn = gr.Button("🧩 Train Model", variant="secondary")
164
  status = gr.Markdown("Status: Ready")
165
 
166
+ # Load or train model
167
  model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
168
 
169
  send_btn.click(enhance_fn, [user_input, chatbot], chatbot)
170
  user_input.submit(enhance_fn, [user_input, chatbot], chatbot)
171
  clear_btn.click(lambda: [], None, chatbot)
172
+
173
+ def retrain():
174
+ global model, tokenizer, embed_model, enhance_fn
175
+ model, tokenizer, embed_model = train_flashpack_model()
176
+ enhance_fn = get_flashpack_model()[3]
177
+ return "✅ Model retrained and pushed to HF!"
178
+
179
+ train_btn.click(retrain, None, status)
180
 
181
  if __name__ == "__main__":
182
  demo.launch(show_error=True)