| import os |
| import gc |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import tempfile |
| import gradio as gr |
| from datasets import load_dataset |
| from transformers import AutoTokenizer, AutoModel |
| from flashpack import FlashPackMixin |
| from huggingface_hub import Repository, list_repo_files, hf_hub_download |
|
|
# All tensors and models in this file are placed on the CPU, and PyTorch's
# thread pool is capped at four threads.
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only mode)")
|
|
| |
| |
| |
class GemmaTrainer(nn.Module, FlashPackMixin):
    """Three-layer fully connected network mapping one embedding to another.

    The layer sizes default to the values that were previously hard-coded, so
    ``GemmaTrainer()`` builds exactly the same architecture as before.

    Args:
        input_dim: Width of the incoming embedding vector.
        hidden_dim: Width of both hidden layers.
        output_dim: Width of the produced embedding vector.
    """

    def __init__(self, input_dim: int = 1536, hidden_dim: int = 1024,
                 output_dim: int = 1536):
        super().__init__()
        # The attribute names (fc1/fc2/fc3) become parameter keys in the saved
        # weights, and the creation order fixes the random initialisation
        # sequence, so neither is changed here.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Linear -> ReLU -> Linear -> ReLU -> Linear to ``x``."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
|
|
| |
| |
| |
def build_encoder(model_name="gpt2", max_length: int = 128):
    """Load a tokenizer and encoder model and build a prompt-embedding function.

    Args:
        model_name: Checkpoint name passed to ``AutoTokenizer`` and ``AutoModel``.
        max_length: Token length every prompt is truncated or padded to.

    Returns:
        A ``(tokenizer, embed_model, encode)`` tuple. ``encode(prompt)`` returns
        a CPU tensor holding the mean-pooled and max-pooled last hidden states
        concatenated along dimension 1.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # Padding needs a pad token; reuse EOS when the tokenizer defines none.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    backbone = AutoModel.from_pretrained(model_name).to(device)
    backbone.eval()

    def encode(prompt: str) -> torch.Tensor:
        """Embed one prompt without tracking gradients."""
        with torch.no_grad():
            batch = tok(
                prompt,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=max_length,
            ).to(device)
            hidden = backbone(**batch).last_hidden_state
            # NOTE(review): both poolings run over every position, including
            # the padding added by padding="max_length". Changing this would
            # alter the embeddings existing weights were trained on.
            pooled_mean = hidden.mean(dim=1)
            pooled_max = hidden.max(dim=1).values
            return torch.cat([pooled_mean, pooled_max], dim=1).cpu()

    return tok, backbone, encode
|
|
| |
| |
| |
def push_flashpack_model_to_hf(model, hf_repo, log_fn):
    """Save ``model`` in FlashPack format and push it to a Hugging Face repo.

    The repository is cloned into a temporary directory, ``model.flashpack``
    and a short ``README.md`` are written into the clone (overwriting any
    files with those names), and the result is pushed back.

    Args:
        model: Object exposing ``save_flashpack(path, target_dtype=...)``.
        hf_repo: Repository identifier to clone from and push to.
        log_fn: Callable taking one progress-message string.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_fn(f"📦 Preparing repository {hf_repo}...")
        # NOTE(review): `Repository` is reported as deprecated in recent
        # huggingface_hub releases in favour of the HTTP upload helpers;
        # confirm against the installed version before upgrading.
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        model.save_flashpack(os.path.join(tmp_dir, "model.flashpack"), target_dtype=torch.float32)
        # Explicit encoding keeps the file identical across platforms.
        with open(os.path.join(tmp_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write("# FlashPack Model\nTrained locally and pushed to HF.")
        log_fn("⏳ Pushing model to Hugging Face...")
        repo.push_to_hub()
        log_fn(f"✅ Model pushed to {hf_repo}")
|
|
| |
| |
| |
def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
                          hf_repo="rahul7star/FlashPack",
                          max_encode=1000):
    """Train the embedding mapper on prompt pairs and push it to the Hub.

    Each dataset row's ``short_prompt`` and ``long_prompt`` are embedded, a
    ``GemmaTrainer`` is fitted full-batch for up to 20 epochs to maximise the
    cosine similarity between the mapped short embedding and the long
    embedding, and the trained model is pushed to ``hf_repo``.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        hf_repo: Repository the trained model is pushed to.
        max_encode: Upper bound on the number of rows used for training.

    Returns:
        ``(model, tokenizer, embed_model, enhance_fn, logs)`` where
        ``enhance_fn(prompt, chat)`` returns the chat history extended with
        the user prompt and a reply, and ``logs`` lists the progress messages.

    Raises:
        ValueError: If no rows are available to train on.
    """
    logs = []

    def log_fn(msg):
        logs.append(msg)
        print(msg)

    log_fn("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    # Clamp the request so datasets smaller than max_encode do not make
    # select() fail on out-of-range indices.
    dataset = dataset.select(range(min(max_encode, len(dataset))))
    total = len(dataset)
    log_fn(f"✅ Loaded {total} samples")
    if total == 0:
        raise ValueError(f"No samples available in dataset {dataset_name!r}")

    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    s_list, l_list = [], []
    for i, item in enumerate(dataset):
        s_list.append(encode_fn(item["short_prompt"]))
        l_list.append(encode_fn(item["long_prompt"]))
        if (i + 1) % 50 == 0:
            log_fn(f" → Encoded {i + 1}/{total}")
            gc.collect()
    short_emb, long_emb = torch.vstack(s_list), torch.vstack(l_list)

    model = GemmaTrainer()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CosineSimilarity(dim=1)

    log_fn("🚀 Training model...")
    num_epochs = 20
    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        preds = model(short_emb)
        # Loss is 1 - mean cosine similarity, so it reaches 0 at perfect alignment.
        loss = 1 - loss_fn(preds, long_emb).mean()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        log_fn(f"Epoch {epoch + 1}/{num_epochs} | Loss: {loss_value:.5f}")
        if loss_value < 0.01:
            log_fn("🎯 Early stopping.")
            break

    # Leave training mode before the model is saved and used for inference.
    model.eval()
    push_flashpack_model_to_hf(model, hf_repo, log_fn)
    # The encoder built above is reused; it is not loaded a second time.

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        """Return the chat history extended with the prompt and a reply."""
        # The mapped embedding is computed but not yet turned into text; the
        # reply below is a fixed template around the original prompt.
        _mapped = model(encode_fn(prompt).to(device)).cpu()
        long_prompt = f"🌟 Enhanced prompt (embedding-based) for: {prompt}"
        return [
            *(chat or []),
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": long_prompt},
        ]

    return model, tokenizer, embed_model, enhance_fn, logs
|
|
| |
| |
| |
def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
    """Load a trained model from disk or the Hub and build an enhance function.

    A ``model.flashpack`` file in the working directory is used when present.
    Otherwise the file is downloaded from ``hf_repo`` if the repository
    lists it.

    Args:
        hf_repo: Repository searched for ``model.flashpack`` when no local
            file exists.

    Returns:
        ``(model, tokenizer, embed_model, enhance_fn)`` on success, or four
        ``None`` values when the repository has no model file or cannot be
        reached.
    """
    weights_name = "model.flashpack"
    local_model_path = weights_name

    if os.path.exists(local_model_path):
        print("✅ Loading local model")
    else:
        # Deliberately best-effort: any failure talking to the Hub is reported
        # and treated as "no model available" so the caller can offer training.
        try:
            if weights_name not in list_repo_files(hf_repo):
                print("🚫 No pretrained model found")
                return None, None, None, None
            print("✅ Downloading model from HF")
            local_model_path = hf_hub_download(repo_id=hf_repo, filename=weights_name)
        except Exception as e:
            print(f"⚠️ Error accessing HF: {e}")
            return None, None, None, None

    model = GemmaTrainer().from_flashpack(local_model_path)
    model.eval()
    tokenizer, embed_model, encode_fn = build_encoder("gpt2")

    @torch.no_grad()
    def enhance_fn(prompt, chat):
        """Return the chat history extended with the prompt and a reply."""
        # The mapped embedding is computed but not yet turned into text; the
        # reply below is a fixed template around the original prompt.
        _mapped = model(encode_fn(prompt).to(device)).cpu()
        long_prompt = f"🌟 Enhanced prompt (embedding-based) for: {prompt}"
        return [
            *(chat or []),
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": long_prompt},
        ]

    return model, tokenizer, embed_model, enhance_fn
|
|
| |
| |
| |
with gr.Blocks(title="✨ FlashPack Prompt Enhancer") as demo:
    gr.Markdown("## 🧠 FlashPack Prompt Enhancer (CPU)\nShort → Long prompt expander")

    chatbot = gr.Chatbot(height=400, type="messages")
    user_input = gr.Textbox(label="Your prompt")
    send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
    clear_btn = gr.Button("🧹 Clear")
    train_btn = gr.Button("🧩 Train Model", variant="secondary")

    # Try to load an existing model; all four values are None when none is found.
    model, tokenizer, embed_model, enhance_fn = get_flashpack_model()
    logs = []

    if enhance_fn is None:
        def enhance_fn(prompt, chat):
            """Fallback handler used until a model has been trained."""
            return [
                *(chat or []),
                {"role": "user", "content": prompt},
                {"role": "assistant",
                 "content": "⚠️ No pretrained model found. Please click 'Train Model' to create one."},
            ]
        logs.append("⚠️ No pretrained model found. Ready to train.")
    else:
        logs.append("✅ Model loaded — ready to enhance.")

    # Created after the load attempt so the start-up status is visible.
    log_output = gr.Textbox(label="Logs", lines=15, value="\n".join(logs))

    def run_enhance(prompt, chat):
        """Forward to whichever ``enhance_fn`` is current.

        Gradio keeps the function object given at registration time, so the
        handlers are bound to this dispatcher; it reads the module-level
        ``enhance_fn`` on every call and so uses a retrained model.
        """
        return enhance_fn(prompt, chat)

    send_btn.click(run_enhance, [user_input, chatbot], chatbot)
    user_input.submit(run_enhance, [user_input, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    def retrain():
        """Train a new model, swap it in, and return the log text."""
        global model, tokenizer, embed_model, enhance_fn, logs
        logs = ["🚀 Training model, please wait..."]
        model, tokenizer, embed_model, enhance_fn, train_logs = train_flashpack_model()
        logs.extend(train_logs)
        # Returning the string sets the Textbox value directly; the
        # component-level `update` classmethod is not available in the Gradio
        # versions that accept Chatbot(type="messages").
        return "\n".join(logs)

    train_btn.click(retrain, None, log_output)


if __name__ == "__main__":
    demo.launch(show_error=True)
|
|