rahul7star committed · Commit 9b6142b · verified · 1 Parent(s): d191426

Update app_flash.py

Files changed (1): app_flash.py +27 -25
app_flash.py CHANGED
@@ -1,4 +1,4 @@
-# prompt_enhancer_flashpack_cpu.py
+# prompt_enhancer_flashpack_cpu_publish.py
 import gc
 import torch
 import torch.nn as nn
@@ -13,14 +13,14 @@ from typing import Tuple
 # 🖥 Force CPU mode
 # ============================================================
 device = torch.device("cpu")
-torch.set_num_threads(4)  # reduce CPU contention
+torch.set_num_threads(4)
 print(f"🔧 Forcing device: {device} (CPU-only mode)")
 
 # ============================================================
 # 1️⃣ Define FlashPack model
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int = 768, hidden_dim: int = 1024, output_dim: int = 768):
+    def __init__(self, input_dim: int = 768, hidden_dim: int = 512, output_dim: int = 768):
         super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
@@ -58,30 +58,23 @@ def build_encoder(model_name="gpt2", max_length: int = 32):
     return tokenizer, embed_model, encode
 
 # ============================================================
-# 3️⃣ Train FlashPack mapping (CPU-optimized)
+# 3️⃣ Train and push FlashPack model
 # ============================================================
-def train_flashpack_model(
+def train_and_push_flashpack(
     dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
-    model_name: str = "gpt2",
-    max_length: int = 32,
-    max_encode: int = 1000,  # use smaller number for CPU
-    push_to_hub: bool = False,
     hf_repo: str = "rahul7star/FlashPack",
+    max_encode: int = 1000,
+    push_to_hub: bool = True,
 ) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
 
-    # 1️⃣ Load dataset
     print("📦 Loading dataset...")
     dataset = load_dataset(dataset_name, split="train")
-
-    # Limit dataset to max_encode prompts
     limit = min(max_encode, len(dataset))
     dataset = dataset.select(range(limit))
    print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
 
-    # 2️⃣ Setup encoder
-    tokenizer, embed_model, encode_fn = build_encoder(model_name, max_length)
+    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
 
-    # 3️⃣ Encode dataset
     print("🔢 Encoding dataset into embeddings (CPU-friendly)...")
     short_list, long_list = [], []
     for i, item in enumerate(dataset):
@@ -96,7 +89,6 @@ def train_flashpack_model(
     long_embeddings = torch.vstack(long_list)
     print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
 
-    # 4️⃣ Initialize & train model
     model = GemmaTrainer(
         input_dim=short_embeddings.shape[1],
         hidden_dim=min(512, short_embeddings.shape[1]),
@@ -132,25 +124,35 @@
 
     print("✅ Training finished!")
 
-    # 5️⃣ Push to HF repo if requested
     if push_to_hub:
+        print(f"📤 Pushing model to Hugging Face repo: {hf_repo} ...")
         model.save_flashpack(hf_repo, target_dtype=torch.float32, push_to_hub=True)
         print(f"✅ Model pushed to HF repo: {hf_repo}")
 
     return model, dataset, embed_model, tokenizer, long_embeddings
 
 # ============================================================
-# 4️⃣ Run training & load model
+# 4️⃣ Load trained model from HF repo
+# ============================================================
+def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
+    model = GemmaTrainer.load_flashpack(hf_repo)
+    model.eval()
+    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
+    return model, tokenizer, embed_model
+
+# ============================================================
+# 5️⃣ Run training + push, then reload
 # ============================================================
-model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model(
-    max_encode=1000,  # safe CPU-friendly subset
-    push_to_hub=False
+model, dataset, embed_model, tokenizer, long_embeddings = train_and_push_flashpack(
+    max_encode=1000,  # CPU-safe
+    push_to_hub=True
 )
 
-model.eval()
+# reload to ensure FlashPack workflow works
+model, tokenizer, embed_model = load_flashpack_model("rahul7star/FlashPack")
 
 # ============================================================
-# 5️⃣ Inference helpers
+# 6️⃣ Inference helpers
 # ============================================================
 @torch.no_grad()
 def encode_for_inference(prompt: str) -> torch.Tensor:
@@ -182,7 +184,7 @@ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_h
     return chat_history
 
 # ============================================================
-# 6️⃣ Gradio UI
+# 7️⃣ Gradio UI
 # ============================================================
 with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
     gr.Markdown(
@@ -207,7 +209,7 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft
     clear_btn.click(lambda: [], None, chatbot)
 
 # ============================================================
-# 7️⃣ Launch
+# 8️⃣ Launch
 # ============================================================
 if __name__ == "__main__":
     demo.launch(show_error=True)
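
After this commit the script pushes the trained mapper to rahul7star/FlashPack and then reloads it from the Hub before serving the Gradio UI. A minimal standalone sketch of that reload-and-infer path follows; it assumes app_flash.py is importable as a module and that encode_fn accepts a single prompt string and returns a (1, 768) tensor (the body of encode falls outside the hunks shown, so both are assumptions):

# Sketch only: exercises the load_flashpack_model() path added in this commit.
# Assumptions not confirmed by the diff: app_flash is importable as a module,
# and encode_fn takes one prompt string and returns a (1, 768) tensor.
import torch

from app_flash import GemmaTrainer, build_encoder  # hypothetical import

# Pull the FlashPack weights pushed by train_and_push_flashpack()
model = GemmaTrainer.load_flashpack("rahul7star/FlashPack")
model.eval()

# Same GPT-2 encoder setup the training run used
tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)

@torch.no_grad()
def enhance_embedding(prompt: str) -> torch.Tensor:
    short_emb = encode_fn(prompt)   # short-prompt embedding (assumed (1, 768))
    return model(short_emb)         # mapped toward the long-prompt embedding space

print(enhance_embedding("a cat in space").shape)

Note that the mapper outputs an embedding, not text; the chat-facing side is handled by the enhance_prompt helper already present in app_flash.py.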