rahul7star committed on
Commit
8191e5c
·
verified ·
1 Parent(s): fad9b48

Update app_flash.py

Browse files
Files changed (1) hide show
  1. app_flash.py +51 -4
app_flash.py CHANGED
@@ -60,6 +60,53 @@ def build_encoder(model_name="gpt2", max_length: int = 32):
60
  # ============================================================
61
  # 3️⃣ Train and push FlashPack model
62
  # ============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def train_and_push_flashpack(
64
  dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
65
  hf_repo: str = "rahul7star/FlashPack",
@@ -125,10 +172,10 @@ def train_and_push_flashpack(
125
  print("✅ Training finished!")
126
 
127
  if push_to_hub:
128
- print(f"📤 Pushing model to Hugging Face repo: {hf_repo} ...")
129
- model.save_flashpack(hf_repo, target_dtype=torch.float32) # save locally
130
- model.push_to_hub(hf_repo) # then push to HF repo
131
- print(f"✅ Model pushed to HF repo: {hf_repo}")
132
 
133
  return model, dataset, embed_model, tokenizer, long_embeddings
134
 
 
60
  # ============================================================
61
  # 3️⃣ Train and push FlashPack model
62
  # ============================================================
63
+ import os
64
+ import tempfile
65
+ from huggingface_hub import hf_hub_download, HfApi
66
+
67
+ # ------------------------------------------------------------
68
+ # Utility to push FlashPack model to HF using upload_file
69
+ # ------------------------------------------------------------
70
def push_flashpack_to_hub_local(model, hf_repo: str, size_limit_mb=100, api=None) -> str:
    """Save a FlashPack model to a scratch directory and upload its files individually.

    The model is written with ``model.save_flashpack`` and every regular file
    found in the scratch directory is sent with ``upload_file``, one commit per
    file. This avoids calling ``push_to_hub()``, which may not be compatible.

    Args:
        model: Object exposing ``save_flashpack(path)``.
        hf_repo: Target repository id; uploads use ``repo_type="model"``.
        size_limit_mb: Files strictly larger than this many MiB are skipped
            (and reported in the log) instead of uploaded.
        api: Optional client exposing ``upload_file``; lets the caller supply a
            pre-configured client. A fresh ``HfApi()`` is created when omitted.

    Returns:
        The newline-joined progress log. Upload errors are recorded in the log
        rather than raised, so one bad file does not abort the others.
    """
    logs = []
    failed = 0

    # TemporaryDirectory deletes the scratch copy on exit, even when saving or
    # uploading raises, so repeated pushes no longer pile up model copies.
    with tempfile.TemporaryDirectory() as temp_dir:
        # NOTE(review): save_flashpack receives a directory path here, as in the
        # previous version -- confirm it accepts a directory rather than a file.
        model.save_flashpack(temp_dir)
        logs.append(f"📦 Model saved locally to {temp_dir}")

        # Only top-level regular files are uploaded; sorted so that the upload
        # order and the log are deterministic across platforms.
        files = sorted(
            name
            for name in os.listdir(temp_dir)
            if os.path.isfile(os.path.join(temp_dir, name))
        )
        logs.append(f"📄 Found {len(files)} files to upload.")

        if api is None:
            api = HfApi()

        for i, file in enumerate(files, start=1):
            file_path = os.path.join(temp_dir, file)
            size_mb = os.path.getsize(file_path) / (1024 * 1024)
            if size_mb > size_limit_mb:
                logs.append(f"⏭️ Skipping {file}, size {size_mb:.2f}MB > limit {size_limit_mb}MB")
                continue

            logs.append(f"⬆️ Uploading {file} ({i}/{len(files)}) ...")
            try:
                api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=file,
                    repo_id=hf_repo,
                    repo_type="model",
                    commit_message=f"Add {file} from FlashPack training",
                )
            except Exception as e:
                # Deliberately best-effort: record the failure and keep going.
                failed += 1
                logs.append(f"⚠️ Failed to upload {file}: {e}")
            else:
                logs.append(f"✅ Uploaded {file}")

    # Only claim success when no upload raised.
    if failed:
        logs.append(f"⚠️ Model push to {hf_repo} finished with {failed} failed upload(s)")
    else:
        logs.append(f"🎉 Model push complete to {hf_repo}")
    return "\n".join(logs)
109
+
110
  def train_and_push_flashpack(
111
  dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
112
  hf_repo: str = "rahul7star/FlashPack",
 
172
  print("✅ Training finished!")
173
 
174
  if push_to_hub:
175
+ print("📤 Pushing FlashPack model to Hugging Face repo...")
176
+ logs = push_flashpack_to_hub_local(model, hf_repo)
177
+ print(logs)
178
+
179
 
180
  return model, dataset, embed_model, tokenizer, long_embeddings
181