AlekseyCalvin commited on
Commit
459f6e8
·
verified ·
1 Parent(s): e859d40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -3
app.py CHANGED
@@ -16,7 +16,13 @@ from safetensors.torch import load_file, save_file
16
  from tqdm import tqdm
17
 
18
  # --- Import Helpers ---
19
- from merge_utils import execute_mergekit_config, build_full_merge_config, build_moe_config
 
 
 
 
 
 
20
  from dare_utils import task_dare_custom
21
 
22
  # --- Memory Efficient Safetensors ---
@@ -601,16 +607,17 @@ def task_full_mergekit_merge(hf_token, models_text, method, dtype, base_model, w
601
  # TAB 6: MOE CREATION
602
  # =================================================================================
603
 
604
- def task_moe_create(hf_token, base_model, experts_text, gate_mode, dtype, tok_source, shard_size, out_repo, private):
605
  cleanup_temp()
606
  if not hf_token or not out_repo: return "Error: Token and Output Repo required."
607
  login(hf_token.strip())
608
 
609
  experts = [e.strip() for e in experts_text.split('\n') if e.strip()]
 
610
 
611
  # 1. Build Config
612
  config = build_moe_config(
613
- base_model=base_model, experts=experts, gate_mode=gate_mode,
614
  dtype=dtype, tokenizer_source=tok_source
615
  )
616
 
@@ -624,6 +631,23 @@ def task_moe_create(hf_token, base_model, experts_text, gate_mode, dtype, tok_so
624
  except Exception as e:
625
  return f"MoE Error: {e}"
626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
  # =================================================================================
628
  # UI
629
  # =================================================================================
@@ -760,6 +784,22 @@ with gr.Blocks() as demo:
760
  t7_res = gr.Textbox(label="Result")
761
 
762
  t7_btn.click(task_dare_custom, [t7_token, t7_base, t7_ft, t7_ratio, t7_mask, t7_out, t7_priv], t7_res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
763
 
764
  if __name__ == "__main__":
765
  demo.queue().launch(css=css, ssr_mode=False)
 
16
  from tqdm import tqdm
17
 
18
  # --- Import Helpers ---
19
+ from merge_utils import (
20
+ execute_mergekit_config,
21
+ execute_raw_pytorch,
22
+ build_full_merge_config,
23
+ build_moe_config,
24
+ build_raw_config
25
+ )
26
  from dare_utils import task_dare_custom
27
 
28
  # --- Memory Efficient Safetensors ---
 
607
  # TAB 6: MOE CREATION
608
  # =================================================================================
609
 
610
+ def task_moe_create(hf_token, base_model, experts_text, prompts_text, gate_mode, dtype, tok_source, shard_size, out_repo, private):
611
  cleanup_temp()
612
  if not hf_token or not out_repo: return "Error: Token and Output Repo required."
613
  login(hf_token.strip())
614
 
615
  experts = [e.strip() for e in experts_text.split('\n') if e.strip()]
616
+ prompts = [p.strip() for p in prompts_text.split('\n') if p.strip()]
617
 
618
  # 1. Build Config
619
  config = build_moe_config(
620
+ base_model=base_model, experts=experts, prompts=prompts, gate_mode=gate_mode,
621
  dtype=dtype, tokenizer_source=tok_source
622
  )
623
 
 
631
  except Exception as e:
632
  return f"MoE Error: {e}"
633
 
634
+ # --- TAB 8: Raw PyTorch (New) ---
635
+ def task_raw_pytorch(hf_token, models_text, method, dtype, base_model, weights, shard_size, out_repo, private):
636
+ cleanup_temp()
637
+ if not hf_token or not out_repo: return "Error: Token and Output Repo required."
638
+ login(hf_token.strip())
639
+
640
+ models = [m.strip() for m in models_text.split('\n') if m.strip()]
641
+ config = build_raw_config(method, models, base_model, dtype, weights)
642
+
643
+ out_path = TempDir / "raw_merged"
644
+ try:
645
+ execute_raw_pytorch(config, str(out_path), shard_size)
646
+ api.create_repo(repo_id=out_repo, private=private, exist_ok=True, token=hf_token)
647
+ api.upload_folder(folder_path=str(out_path), repo_id=out_repo, token=hf_token)
648
+ return f"Success! Raw merge uploaded to {out_repo}"
649
+ except Exception as e: return f"Raw Merge Error: {e}"
650
+
651
  # =================================================================================
652
  # UI
653
  # =================================================================================
 
784
  t7_res = gr.Textbox(label="Result")
785
 
786
  t7_btn.click(task_dare_custom, [t7_token, t7_base, t7_ft, t7_ratio, t7_mask, t7_out, t7_priv], t7_res)
787
+
788
+ with gr.Tab("Raw PyTorch Merge"):
789
+ gr.Markdown("### 🧠 Raw Weight Merging (Non-Transformers)")
790
+ t8_token = gr.Textbox(label="HF Token", type="password")
791
+ t8_method = gr.Dropdown(["Linear", "TIES", "Task_Arithmetic"], value="Linear", label="Method")
792
+ t8_models = gr.TextArea(label="Models (Path/Repo)")
793
+ with gr.Row():
794
+ t8_base = gr.Textbox(label="Base Model (Optional)")
795
+ t8_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="Dtype")
796
+ t8_weights = gr.Textbox(label="Weights")
797
+ t8_shard = gr.Slider(0.5, 10, 2.0, label="Shard Size (GB)")
798
+ t8_out = gr.Textbox(label="Output Repo")
799
+ t8_priv = gr.Checkbox(label="Private", value=True)
800
+ t8_btn = gr.Button("Merge Raw Weights")
801
+ t8_res = gr.Textbox(label="Result")
802
+ t8_btn.click(task_raw_pytorch, [t8_token, t8_models, t8_method, t8_dtype, t8_base, t8_weights, t8_shard, t8_out, t8_priv], t8_res)
803
 
804
  if __name__ == "__main__":
805
  demo.queue().launch(css=css, ssr_mode=False)