import gc
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tempfile
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository
from typing import Tuple
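
# FlashPackMixin provides the save_flashpack() / from_flashpack() helpers that
# the model class below uses to serialize and load its weights.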


# CPU-only setup: pin computation to the CPU and cap the thread count.
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only)")


class GemmaTrainer(nn.Module, FlashPackMixin):
    """Small MLP that maps short-prompt embeddings to long-prompt embeddings."""

    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 1536):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
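
# Note: with the GPT-2 encoder built below (768-dim hidden states, mean-pool and
# max-pool concatenated), input_dim and output_dim both come out to 2 * 768 = 1536,
# matching the output_dim default above.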


def build_encoder(model_name: str = "gpt2", max_length: int = 128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # GPT-2 ships without a pad token, so reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                           padding="max_length", max_length=max_length).to(device)
        last_hidden = embed_model(**inputs).last_hidden_state
        # Concatenate mean- and max-pooled token states into one fixed-size vector.
        mean_pool = last_hidden.mean(dim=1)
        max_pool, _ = last_hidden.max(dim=1)
        return torch.cat([mean_pool, max_pool], dim=1).cpu()

    return tokenizer, embed_model, encode
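
# Example usage (a quick sanity check; shapes assume GPT-2's 768-dim hidden states):
#   tokenizer, embed_model, encode = build_encoder("gpt2")
#   encode("a cat").shape  # -> torch.Size([1, 1536])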


def push_flashpack_model_to_hf(model, hf_repo: str):
    logs = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        logs.append(f"📂 Temporary directory: {tmp_dir}")
        # Clone the target repo, write the serialized weights plus a README into
        # it, then push everything back to the Hub.
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        pack_path = os.path.join(tmp_dir, "model.flashpack")
        model.save_flashpack(pack_path, target_dtype=torch.float32)
        readme_path = os.path.join(tmp_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
        repo.push_to_hub()
        logs.append(f"✅ Model pushed to HF: {hf_repo}")
    return logs
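
# Note: `Repository` is deprecated in recent huggingface_hub releases. A minimal
# sketch of the HTTP-based alternative, assuming the same tmp_dir layout:
#   from huggingface_hub import HfApi
#   HfApi().upload_folder(folder_path=tmp_dir, repo_id=hf_repo)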


def train_flashpack_model(
    dataset_name: str = "rahul7star/prompt-enhancer-dataset",
    max_encode: int = 1000,
    hidden_dim: int = 1024,
    push_to_hub: bool = True,
    hf_repo: str = "rahul7star/FlashPack",
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
    print("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    limit = min(max_encode, len(dataset))
    dataset = dataset.select(range(limit))
    print(f"⚡ Using {len(dataset)} prompts for training")

    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)

    # Encode every (short, long) prompt pair into fixed-size embeddings.
    short_list, long_list = [], []
    for i, item in enumerate(dataset):
        short_list.append(encode_fn(item["short_prompt"]))
        long_list.append(encode_fn(item["long_prompt"]))
        if (i + 1) % 50 == 0 or (i + 1) == limit:
            print(f"  → Encoded {i + 1}/{limit} prompts")
            gc.collect()

    short_embeddings = torch.vstack(short_list)
    long_embeddings = torch.vstack(long_list)
    print(f"✅ Encoded embeddings shape: short {short_embeddings.shape}, long {long_embeddings.shape}")

    input_dim = short_embeddings.shape[1]
    output_dim = long_embeddings.shape[1]

    model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)

    # Train the mapper to maximize cosine similarity between predicted and
    # target long-prompt embeddings (loss = 1 - mean cosine similarity).
    criterion = nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    max_epochs = 50
    batch_size = 32
    n = short_embeddings.shape[0]

    print("🚀 Training model...")
    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0.0
        perm = torch.randperm(n)
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            inputs = short_embeddings[idx].to(device)
            targets = long_embeddings[idx].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = 1 - criterion(outputs, targets).mean()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)

        epoch_loss /= n
        if epoch % 5 == 0 or epoch == max_epochs - 1:
            print(f"Epoch {epoch + 1}/{max_epochs}, Loss={epoch_loss:.6f}")

    print("✅ Training finished!")

    if push_to_hub:
        logs = push_flashpack_model_to_hf(model, hf_repo)
        for log in logs:
            print(log)

    return model, dataset, embed_model, tokenizer, long_embeddings
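
# The trained mapper is used for retrieval rather than generation: at inference
# time its output embedding is compared against the stored long-prompt
# embeddings, and the closest dataset entry is returned verbatim.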


def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    try:
        print(f"🔁 Attempting to load FlashPack model from {hf_repo}")
        model = GemmaTrainer.from_flashpack(hf_repo)
        model.eval()
        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
        # Rebuild the retrieval corpus so inference has candidates to rank;
        # this mirrors the dataset preparation in train_flashpack_model().
        dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train")
        dataset = dataset.select(range(min(1000, len(dataset))))
        long_embeddings = torch.vstack([encode_fn(item["long_prompt"]) for item in dataset])
        return model, tokenizer, embed_model, dataset, long_embeddings
    except Exception as e:
        print(f"⚠️ Load failed: {e}")
        print("⏬ Training a new FlashPack model locally...")
        # train_flashpack_model() already pushes to the Hub (push_to_hub=True).
        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
        return model, tokenizer, embed_model, dataset, long_embeddings
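
# Load (or train) everything once at import time so the Gradio handlers below
# can reuse these globals.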
model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
model.eval()  # no-op for this architecture, but keeps inference mode explicit


@torch.no_grad()
def encode_for_inference(prompt: str) -> torch.Tensor:
    # Mirrors the encode() closure from build_encoder(): mean+max pooled states.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=128).to(device)
    last_hidden = embed_model(**inputs).last_hidden_state
    mean_pool = last_hidden.mean(dim=1)
    max_pool, _ = last_hidden.max(dim=1)
    return torch.cat([mean_pool, max_pool], dim=1).cpu()


def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
    # temperature and max_tokens are currently unused: this is nearest-neighbour
    # retrieval over the dataset, not text generation.
    chat_history = chat_history or []
    short_emb = encode_for_inference(user_prompt)
    with torch.no_grad():
        mapped = model(short_emb.to(device)).cpu()

    # Cosine similarity between the mapped embedding and every long-prompt embedding.
    sims = (long_embeddings @ mapped.t()).squeeze(1)
    long_norms = long_embeddings.norm(dim=1)
    mapped_norm = mapped.norm()
    sims = sims / (long_norms * (mapped_norm + 1e-12))

    best_idx = int(sims.argmax().item())
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history
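
# Quick manual check (hypothetical inputs; run after the module has initialized):
#   enhance_prompt("a cat on a hill", temperature=0.7, max_tokens=128, chat_history=[])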


with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (FlashPack mapper)
        Enter a short prompt, and the model will **expand it with details and creative context**.
        (CPU-only mode.)
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)


if __name__ == "__main__":
    demo.launch(show_error=True)