|
|
|
|
|
""" |
|
|
Model Download and Setup for FinEE v2.0 |
|
|
======================================== |
|
|
|
|
|
Downloads and prepares base models for fine-tuning: |
|
|
- Llama 3.1 8B Instruct (Primary) |
|
|
- Qwen2.5 7B Instruct (Backup) |
|
|
|
|
|
Supports: |
|
|
- MLX format for Apple Silicon |
|
|
- PyTorch/Transformers format |
|
|
- GGUF for llama.cpp |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import subprocess |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Registry of supported base models, keyed by the short name used on the CLI.
# Per-entry fields:
#   hf_name:     HuggingFace repo for the full-precision weights
#   mlx_name:    pre-quantized MLX build for Apple Silicon (optional)
#   gguf_name:   GGUF quantization repo for llama.cpp (optional — absent for phi-3-medium)
#   description: human-readable summary shown by the `list` action
#   size:        parameter count label (display only)
#   context:     context-window label (display only)
MODELS = {
    "llama-3.1-8b": {
        "hf_name": "meta-llama/Llama-3.1-8B-Instruct",
        "mlx_name": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        "gguf_name": "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
        "description": "Llama 3.1 8B Instruct - Best instruction-following",
        "size": "8B",
        "context": "128K",
    },
    "qwen2.5-7b": {
        "hf_name": "Qwen/Qwen2.5-7B-Instruct",
        "mlx_name": "mlx-community/Qwen2.5-7B-Instruct-4bit",
        "gguf_name": "Qwen/Qwen2.5-7B-Instruct-GGUF",
        "description": "Qwen 2.5 7B - Excellent multilingual support",
        "size": "7B",
        "context": "128K",
    },
    "mistral-7b": {
        "hf_name": "mistralai/Mistral-7B-Instruct-v0.3",
        "mlx_name": "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        "gguf_name": "bartowski/Mistral-7B-Instruct-v0.3-GGUF",
        "description": "Mistral 7B - Fast and efficient",
        "size": "7B",
        "context": "32K",
    },
    "phi-3-medium": {
        "hf_name": "microsoft/Phi-3-medium-128k-instruct",
        "mlx_name": "mlx-community/Phi-3-medium-128k-instruct-4bit",
        # no gguf_name: download_gguf_model() reports "not available" for this key
        "description": "Phi-3 Medium - Compact but powerful",
        "size": "14B",
        "context": "128K",
    },
}
|
|
|
|
|
|
|
|
def download_mlx_model(model_key: str, output_dir: Path) -> bool:
    """Download a pre-quantized MLX build of *model_key* from the HF Hub.

    Args:
        model_key: Key into the module-level MODELS registry.
        output_dir: Base directory; files land in <output_dir>/<model_key>/mlx.

    Returns:
        True on success, False if no MLX repo exists for this model or the
        download fails.
    """
    model = MODELS[model_key]
    mlx_name = model.get("mlx_name")

    if not mlx_name:
        print(f"❌ No MLX version available for {model_key}")
        return False

    print(f"\n📥 Downloading {model_key} (MLX format)...")
    print(f" From: {mlx_name}")

    output_path = output_dir / model_key / "mlx"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # Imported lazily so the script can run `list` without huggingface_hub.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=mlx_name,
            local_dir=str(output_path),
            local_dir_use_symlinks=False,  # NOTE: deprecated/ignored in recent huggingface_hub
        )

        # Encoding fix: the success/failure markers were mojibake-garbled
        # (one emoji was even split across two source lines, breaking the
        # f-string); restored to single-line status-emoji messages.
        print(f"✅ Downloaded to: {output_path}")
        return True

    except Exception as e:
        print(f"❌ Download failed: {e}")
        return False
|
|
|
|
|
|
|
|
def download_hf_model(model_key: str, output_dir: Path) -> bool:
    """Download the full HuggingFace weights for *model_key*.

    Args:
        model_key: Key into the module-level MODELS registry.
        output_dir: Base directory; files land in <output_dir>/<model_key>/hf.

    Returns:
        True on success, False if the download fails (e.g. a gated repo
        without a HuggingFace login).
    """
    model = MODELS[model_key]
    hf_name = model["hf_name"]

    print(f"\n📥 Downloading {model_key} (HuggingFace format)...")
    print(f" From: {hf_name}")

    output_path = output_dir / model_key / "hf"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # Imported lazily so the script can run `list` without huggingface_hub.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=hf_name,
            local_dir=str(output_path),
            local_dir_use_symlinks=False,  # NOTE: deprecated/ignored in recent huggingface_hub
            ignore_patterns=["*.bin", "*.h5"],  # prefer safetensors; skip legacy weight formats
        )

        # Encoding fix: status markers were mojibake-garbled (the success
        # emoji split the f-string across two source lines); restored.
        print(f"✅ Downloaded to: {output_path}")
        return True

    except Exception as e:
        print(f"❌ Download failed: {e}")
        print(" Note: Some models require HuggingFace login")
        print(" Run: huggingface-cli login")
        return False
|
|
|
|
|
|
|
|
def download_gguf_model(model_key: str, output_dir: Path, quant: str = "Q4_K_M") -> bool:
    """Download a GGUF-quantized build of *model_key* for llama.cpp.

    Args:
        model_key: Key into the module-level MODELS registry.
        output_dir: Base directory; files land in <output_dir>/<model_key>/gguf.
        quant: Quantization-level substring to match, e.g. "Q4_K_M".

    Returns:
        True on success, False if no GGUF repo exists for this model or the
        download fails.
    """
    model = MODELS[model_key]
    gguf_name = model.get("gguf_name")

    if not gguf_name:
        print(f"❌ No GGUF version available for {model_key}")
        return False

    print(f"\n📥 Downloading {model_key} (GGUF {quant} format)...")
    print(f" From: {gguf_name}")

    output_path = output_dir / model_key / "gguf"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # BUG FIX: hf_hub_download() takes an exact filename and does not
        # expand globs, so the original call with filename=f"*{quant}*.gguf"
        # could never succeed. snapshot_download() supports glob matching
        # via allow_patterns, fetching only the requested quantization.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=gguf_name,
            allow_patterns=[f"*{quant}*.gguf"],
            local_dir=str(output_path),
            local_dir_use_symlinks=False,  # NOTE: deprecated/ignored in recent huggingface_hub
        )

        print(f"✅ Downloaded to: {output_path}")
        return True

    except Exception as e:
        print(f"❌ Download failed: {e}")
        return False
|
|
|
|
|
|
|
|
def convert_to_mlx(model_path: Path, output_path: Path, quantize: bool = True) -> bool:
    """Convert a HuggingFace checkpoint to MLX format via `mlx_lm.convert`.

    Args:
        model_path: Directory (or HF repo id) of the source checkpoint.
        output_path: Destination directory for the MLX weights.
        quantize: If True, also quantize to 4 bits during conversion.

    Returns:
        True if the conversion subprocess exits successfully, else False.
    """
    # Encoding fix: the status markers below were mojibake-garbled emoji.
    print("\n🔄 Converting to MLX format...")

    # Run the converter in a subprocess (list form, shell=False) so a
    # missing/failing mlx_lm install cannot take this script down.
    cmd = [
        sys.executable, "-m", "mlx_lm.convert",
        "--hf-path", str(model_path),
        "--mlx-path", str(output_path),
    ]

    if quantize:
        cmd.extend(["--quantize", "--q-bits", "4"])

    try:
        subprocess.run(cmd, check=True)
        print(f"✅ Converted to: {output_path}")
        return True
    except subprocess.CalledProcessError as e:
        print(f"❌ Conversion failed: {e}")
        return False
|
|
|
|
|
|
|
|
def verify_model(model_path: Path, backend: str = "mlx") -> bool:
    """Smoke-test that a downloaded model can actually be loaded.

    Args:
        model_path: Directory containing the model weights.
        backend: "mlx" (load via mlx_lm) or "transformers" (HuggingFace).

    Returns:
        True if the model loads (and, for MLX, generates a short sample),
        else False. Unknown backends now return False explicitly instead of
        silently falling through and returning None.
    """
    # Encoding fix: the status markers below were mojibake-garbled emoji.
    print(f"\n🔍 Verifying model at {model_path}...")

    if backend == "mlx":
        try:
            from mlx_lm import load, generate

            model, tokenizer = load(str(model_path))

            # Tiny generation to confirm the weights are usable end-to-end.
            output = generate(model, tokenizer, "Hello", max_tokens=10)
            print("✅ Model loaded successfully!")
            print(f" Test output: {output[:50]}...")
            return True
        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return False

    elif backend == "transformers":
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(str(model_path))
            model = AutoModelForCausalLM.from_pretrained(str(model_path))

            print("✅ Model loaded successfully!")
            return True
        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return False

    # Previously an unrecognized backend fell off the end (implicit None);
    # report it and fail explicitly.
    print(f"❌ Unknown backend: {backend}")
    return False
|
|
|
|
|
|
|
|
def list_models():
    """Print a formatted table of every entry in the MODELS registry."""
    # Encoding fix: the header marker was a mojibake-garbled emoji.
    print("\n📋 Available Models:\n")
    print(f"{'Model':<20} {'Size':<8} {'Context':<10} {'Description'}")
    print("-" * 80)

    for key, model in MODELS.items():
        print(f"{key:<20} {model['size']:<8} {model['context']:<10} {model['description']}")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments and dispatch to the chosen action."""
    parser = argparse.ArgumentParser(description="Download and setup base models")
    parser.add_argument("action", choices=["download", "convert", "verify", "list"],
                        help="Action to perform")
    parser.add_argument("-m", "--model", choices=list(MODELS.keys()),
                        default="llama-3.1-8b", help="Model to download")
    parser.add_argument("-f", "--format", choices=["mlx", "hf", "gguf", "all"],
                        default="mlx", help="Model format")
    parser.add_argument("-o", "--output", default="models/base",
                        help="Output directory")
    parser.add_argument("-q", "--quant", default="Q4_K_M",
                        help="GGUF quantization level")

    args = parser.parse_args()

    output_dir = Path(args.output)

    if args.action == "list":
        list_models()
        return

    if args.action == "download":
        # "all" fetches every format; individual formats fetch just one.
        if args.format in ["mlx", "all"]:
            download_mlx_model(args.model, output_dir)

        if args.format in ["hf", "all"]:
            download_hf_model(args.model, output_dir)

        if args.format in ["gguf", "all"]:
            download_gguf_model(args.model, output_dir, args.quant)

    elif args.action == "convert":
        hf_path = output_dir / args.model / "hf"
        mlx_path = output_dir / args.model / "mlx-converted"
        convert_to_mlx(hf_path, mlx_path)

    elif args.action == "verify":
        model_path = output_dir / args.model
        # BUG FIX: the original passed args.format straight through as the
        # backend, but verify_model() only recognizes "mlx" and
        # "transformers" — so `verify -f hf` silently did nothing. Map the
        # CLI format flag to the loader backend explicitly.
        if args.format == "mlx":
            model_path = model_path / "mlx"
            backend = "mlx"
        elif args.format == "hf":
            model_path = model_path / "hf"
            backend = "transformers"
        else:
            backend = args.format

        verify_model(model_path, backend)
|
|
|
|
|
|
|
|
# Script entry point: dispatch to the CLI when run directly.
if __name__ == "__main__":
    main()
|
|
|