Ranjit0034 commited on
Commit
a60c3fc
·
verified ·
1 Parent(s): fef7470

Upload scripts/download_base_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/download_base_model.py +269 -0
scripts/download_base_model.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Model Download and Setup for FinEE v2.0
4
+ ========================================
5
+
6
+ Downloads and prepares base models for fine-tuning:
7
+ - Llama 3.1 8B Instruct (Primary)
8
+ - Qwen2.5 7B Instruct (Backup)
9
+
10
+ Supports:
11
+ - MLX format for Apple Silicon
12
+ - PyTorch/Transformers format
13
+ - GGUF for llama.cpp
14
+ """
15
+
16
+ import argparse
17
+ import os
18
+ import subprocess
19
+ import sys
20
+ from pathlib import Path
21
+
22
+
23
# Registry of base models available for download.
# Each entry maps a short CLI key to:
#   hf_name     - HuggingFace repo id for the full-precision weights
#   mlx_name    - pre-quantized 4-bit MLX repo for Apple Silicon
#   gguf_name   - GGUF repo for llama.cpp (absent for phi-3-medium)
#   description / size / context - display metadata used by list_models()
MODELS = {
    "llama-3.1-8b": {
        "hf_name": "meta-llama/Llama-3.1-8B-Instruct",
        "mlx_name": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        "gguf_name": "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
        "description": "Llama 3.1 8B Instruct - Best instruction-following",
        "size": "8B",
        "context": "128K",
    },
    "qwen2.5-7b": {
        "hf_name": "Qwen/Qwen2.5-7B-Instruct",
        "mlx_name": "mlx-community/Qwen2.5-7B-Instruct-4bit",
        "gguf_name": "Qwen/Qwen2.5-7B-Instruct-GGUF",
        "description": "Qwen 2.5 7B - Excellent multilingual support",
        "size": "7B",
        "context": "128K",
    },
    "mistral-7b": {
        "hf_name": "mistralai/Mistral-7B-Instruct-v0.3",
        "mlx_name": "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        "gguf_name": "bartowski/Mistral-7B-Instruct-v0.3-GGUF",
        "description": "Mistral 7B - Fast and efficient",
        "size": "7B",
        "context": "32K",
    },
    # NOTE: no gguf_name here, so `download gguf` reports "not available".
    "phi-3-medium": {
        "hf_name": "microsoft/Phi-3-medium-128k-instruct",
        "mlx_name": "mlx-community/Phi-3-medium-128k-instruct-4bit",
        "description": "Phi-3 Medium - Compact but powerful",
        "size": "14B",
        "context": "128K",
    },
}
56
+
57
+
58
def download_mlx_model(model_key: str, output_dir: Path) -> bool:
    """Download the pre-quantized MLX snapshot for *model_key*.

    Args:
        model_key: Key into MODELS (e.g. "llama-3.1-8b").
        output_dir: Base directory; files land in <output_dir>/<model_key>/mlx.

    Returns:
        True on success; False if no MLX repo is registered or the
        download raises.
    """
    model = MODELS[model_key]
    mlx_name = model.get("mlx_name")

    if not mlx_name:
        print(f"❌ No MLX version available for {model_key}")
        return False

    print(f"\n📥 Downloading {model_key} (MLX format)...")
    print(f"   From: {mlx_name}")

    output_path = output_dir / model_key / "mlx"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # Imported lazily so listing models works without huggingface_hub.
        from huggingface_hub import snapshot_download

        # FIX: `local_dir_use_symlinks` is deprecated in huggingface_hub
        # >= 0.23 (real files are always written to local_dir now), so the
        # argument is omitted to avoid FutureWarning / future breakage.
        snapshot_download(
            repo_id=mlx_name,
            local_dir=str(output_path),
        )

        print(f"✅ Downloaded to: {output_path}")
        return True

    except Exception as e:
        print(f"❌ Download failed: {e}")
        return False
88
+
89
+
90
def download_hf_model(model_key: str, output_dir: Path) -> bool:
    """Download full-precision weights in HuggingFace format.

    Args:
        model_key: Key into MODELS.
        output_dir: Base directory; files land in <output_dir>/<model_key>/hf.

    Returns:
        True on success, False on failure (e.g. gated repo without login).
    """
    model = MODELS[model_key]
    hf_name = model["hf_name"]

    print(f"\n📥 Downloading {model_key} (HuggingFace format)...")
    print(f"   From: {hf_name}")

    output_path = output_dir / model_key / "hf"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # Imported lazily so listing models works without huggingface_hub.
        from huggingface_hub import snapshot_download

        # FIX: `local_dir_use_symlinks` is deprecated in huggingface_hub
        # >= 0.23 and is omitted here.
        snapshot_download(
            repo_id=hf_name,
            local_dir=str(output_path),
            ignore_patterns=["*.bin", "*.h5"],  # Prefer safetensors
        )

        print(f"✅ Downloaded to: {output_path}")
        return True

    except Exception as e:
        print(f"❌ Download failed: {e}")
        # Gated repos (e.g. meta-llama/*) require an authenticated session.
        print("   Note: Some models require HuggingFace login")
        print("   Run: huggingface-cli login")
        return False
119
+
120
+
121
def download_gguf_model(model_key: str, output_dir: Path, quant: str = "Q4_K_M") -> bool:
    """Download a GGUF quantized model for llama.cpp.

    Args:
        model_key: Key into MODELS.
        output_dir: Base directory; files land in <output_dir>/<model_key>/gguf.
        quant: Quantization tag to match in the repo's filenames
            (e.g. "Q4_K_M", "Q8_0").

    Returns:
        True on success; False if no GGUF repo is registered or the
        download raises.
    """
    model = MODELS[model_key]
    gguf_name = model.get("gguf_name")

    if not gguf_name:
        print(f"❌ No GGUF version available for {model_key}")
        return False

    print(f"\n📥 Downloading {model_key} (GGUF {quant} format)...")
    print(f"   From: {gguf_name}")

    output_path = output_dir / model_key / "gguf"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        from huggingface_hub import snapshot_download

        # BUG FIX: hf_hub_download() takes an *exact* filename and does not
        # expand glob patterns, so filename=f"*{quant}*.gguf" always failed
        # with EntryNotFoundError.  snapshot_download's allow_patterns DOES
        # do glob matching, fetching only the matching quantization file(s).
        snapshot_download(
            repo_id=gguf_name,
            allow_patterns=[f"*{quant}*.gguf"],
            local_dir=str(output_path),
        )

        print(f"✅ Downloaded to: {output_path}")
        return True

    except Exception as e:
        print(f"❌ Download failed: {e}")
        return False
155
+
156
+
157
def convert_to_mlx(model_path: Path, output_path: Path, quantize: bool = True):
    """Convert a HuggingFace checkpoint to MLX format via `mlx_lm.convert`.

    Args:
        model_path: Directory holding the HuggingFace-format model.
        output_path: Destination directory for the MLX-format model.
        quantize: When True, also 4-bit quantize during conversion.

    Returns:
        True if the subprocess exits cleanly, False otherwise.
    """
    print(f"\n🔄 Converting to MLX format...")

    # Build the `python -m mlx_lm.convert ...` command line.
    args = [sys.executable, "-m", "mlx_lm.convert"]
    args += ["--hf-path", str(model_path)]
    args += ["--mlx-path", str(output_path)]
    if quantize:
        args += ["--quantize", "--q-bits", "4"]

    try:
        subprocess.run(args, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Conversion failed: {e}")
        return False
    print(f"✅ Converted to: {output_path}")
    return True
177
+
178
+
179
def verify_model(model_path: Path, backend: str = "mlx") -> bool:
    """Smoke-test that a downloaded model can actually be loaded.

    Args:
        model_path: Directory containing the model files.
        backend: "mlx" or "transformers".

    Returns:
        True if the model loads (and, for MLX, generates a few tokens),
        False on any failure or unknown backend.
    """
    print(f"\n🔍 Verifying model at {model_path}...")

    if backend == "mlx":
        try:
            from mlx_lm import load, generate

            model, tokenizer = load(str(model_path))

            # Tiny generation confirms weights + tokenizer really work.
            output = generate(model, tokenizer, "Hello", max_tokens=10)
            print(f"✅ Model loaded successfully!")
            print(f"   Test output: {output[:50]}...")
            return True
        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return False

    elif backend == "transformers":
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(str(model_path))
            model = AutoModelForCausalLM.from_pretrained(str(model_path))

            print(f"✅ Model loaded successfully!")
            return True
        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return False

    # BUG FIX: previously an unknown backend fell off the end of the
    # function and implicitly returned None instead of a bool.
    print(f"❌ Unknown backend: {backend}")
    return False
210
+
211
+
212
def list_models():
    """Print a formatted table of every entry in MODELS."""
    print("\n📋 Available Models:\n")
    header = f"{'Model':<20} {'Size':<8} {'Context':<10} {'Description'}"
    print(header)
    print("-" * 80)

    for name, info in MODELS.items():
        row = f"{name:<20} {info['size']:<8} {info['context']:<10} {info['description']}"
        print(row)
220
+
221
+
222
def main():
    """CLI entry point: download, convert, verify, or list base models."""
    parser = argparse.ArgumentParser(description="Download and setup base models")
    parser.add_argument("action", choices=["download", "convert", "verify", "list"],
                        help="Action to perform")
    parser.add_argument("-m", "--model", choices=list(MODELS.keys()),
                        default="llama-3.1-8b", help="Model to download")
    parser.add_argument("-f", "--format", choices=["mlx", "hf", "gguf", "all"],
                        default="mlx", help="Model format")
    parser.add_argument("-o", "--output", default="models/base",
                        help="Output directory")
    parser.add_argument("-q", "--quant", default="Q4_K_M",
                        help="GGUF quantization level")

    args = parser.parse_args()
    output_dir = Path(args.output)

    # "list" needs no model/format arguments; handle it and bail early.
    if args.action == "list":
        list_models()
        return

    if args.action == "download":
        # "all" fans out to every format that exists for the model.
        if args.format in ("mlx", "all"):
            download_mlx_model(args.model, output_dir)
        if args.format in ("hf", "all"):
            download_hf_model(args.model, output_dir)
        if args.format in ("gguf", "all"):
            download_gguf_model(args.model, output_dir, args.quant)

    elif args.action == "convert":
        base = output_dir / args.model
        convert_to_mlx(base / "hf", base / "mlx-converted")

    elif args.action == "verify":
        model_path = output_dir / args.model
        # mlx/hf downloads live in per-format subdirectories.
        if args.format in ("mlx", "hf"):
            model_path = model_path / args.format
        verify_model(model_path, args.format)


if __name__ == "__main__":
    main()