Update app_flash.py
Browse files- app_flash.py +27 -25
app_flash.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
#
|
| 2 |
import gc
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
@@ -13,14 +13,14 @@ from typing import Tuple
|
|
| 13 |
# 🖥 Force CPU mode
|
| 14 |
# ============================================================
|
| 15 |
device = torch.device("cpu")
|
| 16 |
-
torch.set_num_threads(4)
|
| 17 |
print(f"🔧 Forcing device: {device} (CPU-only mode)")
|
| 18 |
|
| 19 |
# ============================================================
|
| 20 |
# 1️⃣ Define FlashPack model
|
| 21 |
# ============================================================
|
| 22 |
class GemmaTrainer(nn.Module, FlashPackMixin):
|
| 23 |
-
def __init__(self, input_dim: int = 768, hidden_dim: int =
|
| 24 |
super().__init__()
|
| 25 |
self.fc1 = nn.Linear(input_dim, hidden_dim)
|
| 26 |
self.relu = nn.ReLU()
|
|
@@ -58,30 +58,23 @@ def build_encoder(model_name="gpt2", max_length: int = 32):
|
|
| 58 |
return tokenizer, embed_model, encode
|
| 59 |
|
| 60 |
# ============================================================
|
| 61 |
-
# 3️⃣ Train FlashPack
|
| 62 |
# ============================================================
|
| 63 |
-
def
|
| 64 |
dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
|
| 65 |
-
model_name: str = "gpt2",
|
| 66 |
-
max_length: int = 32,
|
| 67 |
-
max_encode: int = 1000, # use smaller number for CPU
|
| 68 |
-
push_to_hub: bool = False,
|
| 69 |
hf_repo: str = "rahul7star/FlashPack",
|
|
|
|
|
|
|
| 70 |
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
|
| 71 |
|
| 72 |
-
# 1️⃣ Load dataset
|
| 73 |
print("📦 Loading dataset...")
|
| 74 |
dataset = load_dataset(dataset_name, split="train")
|
| 75 |
-
|
| 76 |
-
# Limit dataset to max_encode prompts
|
| 77 |
limit = min(max_encode, len(dataset))
|
| 78 |
dataset = dataset.select(range(limit))
|
| 79 |
print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
|
| 80 |
|
| 81 |
-
|
| 82 |
-
tokenizer, embed_model, encode_fn = build_encoder(model_name, max_length)
|
| 83 |
|
| 84 |
-
# 3️⃣ Encode dataset
|
| 85 |
print("🔢 Encoding dataset into embeddings (CPU-friendly)...")
|
| 86 |
short_list, long_list = [], []
|
| 87 |
for i, item in enumerate(dataset):
|
|
@@ -96,7 +89,6 @@ def train_flashpack_model(
|
|
| 96 |
long_embeddings = torch.vstack(long_list)
|
| 97 |
print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
|
| 98 |
|
| 99 |
-
# 4️⃣ Initialize & train model
|
| 100 |
model = GemmaTrainer(
|
| 101 |
input_dim=short_embeddings.shape[1],
|
| 102 |
hidden_dim=min(512, short_embeddings.shape[1]),
|
|
@@ -132,25 +124,35 @@ def train_flashpack_model(
|
|
| 132 |
|
| 133 |
print("✅ Training finished!")
|
| 134 |
|
| 135 |
-
# 5️⃣ Push to HF repo if requested
|
| 136 |
if push_to_hub:
|
|
|
|
| 137 |
model.save_flashpack(hf_repo, target_dtype=torch.float32, push_to_hub=True)
|
| 138 |
print(f"✅ Model pushed to HF repo: {hf_repo}")
|
| 139 |
|
| 140 |
return model, dataset, embed_model, tokenizer, long_embeddings
|
| 141 |
|
| 142 |
# ============================================================
|
| 143 |
-
# 4️⃣
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
# ============================================================
|
| 145 |
-
model, dataset, embed_model, tokenizer, long_embeddings =
|
| 146 |
-
max_encode=1000, #
|
| 147 |
-
push_to_hub=
|
| 148 |
)
|
| 149 |
|
| 150 |
-
|
|
|
|
| 151 |
|
| 152 |
# ============================================================
|
| 153 |
-
#
|
| 154 |
# ============================================================
|
| 155 |
@torch.no_grad()
|
| 156 |
def encode_for_inference(prompt: str) -> torch.Tensor:
|
|
@@ -182,7 +184,7 @@ def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_h
|
|
| 182 |
return chat_history
|
| 183 |
|
| 184 |
# ============================================================
|
| 185 |
-
#
|
| 186 |
# ============================================================
|
| 187 |
with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
|
| 188 |
gr.Markdown(
|
|
@@ -207,7 +209,7 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft
|
|
| 207 |
clear_btn.click(lambda: [], None, chatbot)
|
| 208 |
|
| 209 |
# ============================================================
|
| 210 |
-
#
|
| 211 |
# ============================================================
|
| 212 |
if __name__ == "__main__":
|
| 213 |
demo.launch(show_error=True)
|
|
|
|
| 1 |
+
# prompt_enhancer_flashpack_cpu_publish.py
|
| 2 |
import gc
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
|
|
| 13 |
# 🖥 Force CPU mode
|
| 14 |
# ============================================================
|
| 15 |
device = torch.device("cpu")
|
| 16 |
+
torch.set_num_threads(4)
|
| 17 |
print(f"🔧 Forcing device: {device} (CPU-only mode)")
|
| 18 |
|
| 19 |
# ============================================================
|
| 20 |
# 1️⃣ Define FlashPack model
|
| 21 |
# ============================================================
|
| 22 |
class GemmaTrainer(nn.Module, FlashPackMixin):
|
| 23 |
+
def __init__(self, input_dim: int = 768, hidden_dim: int = 512, output_dim: int = 768):
|
| 24 |
super().__init__()
|
| 25 |
self.fc1 = nn.Linear(input_dim, hidden_dim)
|
| 26 |
self.relu = nn.ReLU()
|
|
|
|
| 58 |
return tokenizer, embed_model, encode
|
| 59 |
|
| 60 |
# ============================================================
|
| 61 |
+
# 3️⃣ Train and push FlashPack model
|
| 62 |
# ============================================================
|
| 63 |
+
def train_and_push_flashpack(
|
| 64 |
dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
hf_repo: str = "rahul7star/FlashPack",
|
| 66 |
+
max_encode: int = 1000,
|
| 67 |
+
push_to_hub: bool = True,
|
| 68 |
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
|
| 69 |
|
|
|
|
| 70 |
print("📦 Loading dataset...")
|
| 71 |
dataset = load_dataset(dataset_name, split="train")
|
|
|
|
|
|
|
| 72 |
limit = min(max_encode, len(dataset))
|
| 73 |
dataset = dataset.select(range(limit))
|
| 74 |
print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
|
| 75 |
|
| 76 |
+
tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
|
|
|
|
| 77 |
|
|
|
|
| 78 |
print("🔢 Encoding dataset into embeddings (CPU-friendly)...")
|
| 79 |
short_list, long_list = [], []
|
| 80 |
for i, item in enumerate(dataset):
|
|
|
|
| 89 |
long_embeddings = torch.vstack(long_list)
|
| 90 |
print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
|
| 91 |
|
|
|
|
| 92 |
model = GemmaTrainer(
|
| 93 |
input_dim=short_embeddings.shape[1],
|
| 94 |
hidden_dim=min(512, short_embeddings.shape[1]),
|
|
|
|
| 124 |
|
| 125 |
print("✅ Training finished!")
|
| 126 |
|
|
|
|
| 127 |
if push_to_hub:
|
| 128 |
+
print(f"📤 Pushing model to Hugging Face repo: {hf_repo} ...")
|
| 129 |
model.save_flashpack(hf_repo, target_dtype=torch.float32, push_to_hub=True)
|
| 130 |
print(f"✅ Model pushed to HF repo: {hf_repo}")
|
| 131 |
|
| 132 |
return model, dataset, embed_model, tokenizer, long_embeddings
|
| 133 |
|
| 134 |
# ============================================================
|
| 135 |
+
# 4️⃣ Load trained model from HF repo
|
| 136 |
+
# ============================================================
|
| 137 |
+
def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
|
| 138 |
+
model = GemmaTrainer.load_flashpack(hf_repo)
|
| 139 |
+
model.eval()
|
| 140 |
+
tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
|
| 141 |
+
return model, tokenizer, embed_model
|
| 142 |
+
|
| 143 |
+
# ============================================================
|
| 144 |
+
# 5️⃣ Run training + push, then reload
|
| 145 |
# ============================================================
|
| 146 |
+
model, dataset, embed_model, tokenizer, long_embeddings = train_and_push_flashpack(
|
| 147 |
+
max_encode=1000, # CPU-safe
|
| 148 |
+
push_to_hub=True
|
| 149 |
)
|
| 150 |
|
| 151 |
+
# reload to ensure FlashPack workflow works
|
| 152 |
+
model, tokenizer, embed_model = load_flashpack_model("rahul7star/FlashPack")
|
| 153 |
|
| 154 |
# ============================================================
|
| 155 |
+
# 6️⃣ Inference helpers
|
| 156 |
# ============================================================
|
| 157 |
@torch.no_grad()
|
| 158 |
def encode_for_inference(prompt: str) -> torch.Tensor:
|
|
|
|
| 184 |
return chat_history
|
| 185 |
|
| 186 |
# ============================================================
|
| 187 |
+
# 7️⃣ Gradio UI
|
| 188 |
# ============================================================
|
| 189 |
with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
|
| 190 |
gr.Markdown(
|
|
|
|
| 209 |
clear_btn.click(lambda: [], None, chatbot)
|
| 210 |
|
| 211 |
# ============================================================
|
| 212 |
+
# 8️⃣ Launch
|
| 213 |
# ============================================================
|
| 214 |
if __name__ == "__main__":
|
| 215 |
demo.launch(show_error=True)
|