# finance-entity-extractor / scripts/download_base_model.py
# Uploaded with huggingface_hub by Ranjit0034 (commit a60c3fc, verified)
#!/usr/bin/env python3
"""
Model Download and Setup for FinEE v2.0
========================================
Downloads and prepares base models for fine-tuning:
- Llama 3.1 8B Instruct (Primary)
- Qwen2.5 7B Instruct (Backup)
Supports:
- MLX format for Apple Silicon
- PyTorch/Transformers format
- GGUF for llama.cpp
"""
import argparse
import os
import subprocess
import sys
from pathlib import Path
# Registry of supported base models, keyed by the CLI-facing model name.
# Each entry records:
#   hf_name     - full-precision HuggingFace repo (source for fine-tuning/conversion)
#   mlx_name    - pre-quantized 4-bit MLX build for Apple Silicon
#   gguf_name   - GGUF repo for llama.cpp (optional; absent for phi-3-medium)
#   description / size / context - human-readable metadata shown by `list`
# NOTE: insertion order is preserved and determines the `list` output order.
MODELS = {
    "llama-3.1-8b": {
        "hf_name": "meta-llama/Llama-3.1-8B-Instruct",
        "mlx_name": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        "gguf_name": "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
        "description": "Llama 3.1 8B Instruct - Best instruction-following",
        "size": "8B",
        "context": "128K",
    },
    "qwen2.5-7b": {
        "hf_name": "Qwen/Qwen2.5-7B-Instruct",
        "mlx_name": "mlx-community/Qwen2.5-7B-Instruct-4bit",
        "gguf_name": "Qwen/Qwen2.5-7B-Instruct-GGUF",
        "description": "Qwen 2.5 7B - Excellent multilingual support",
        "size": "7B",
        "context": "128K",
    },
    "mistral-7b": {
        "hf_name": "mistralai/Mistral-7B-Instruct-v0.3",
        "mlx_name": "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        "gguf_name": "bartowski/Mistral-7B-Instruct-v0.3-GGUF",
        "description": "Mistral 7B - Fast and efficient",
        "size": "7B",
        "context": "32K",
    },
    "phi-3-medium": {
        # No gguf_name: a GGUF download for this model is intentionally unavailable.
        "hf_name": "microsoft/Phi-3-medium-128k-instruct",
        "mlx_name": "mlx-community/Phi-3-medium-128k-instruct-4bit",
        "description": "Phi-3 Medium - Compact but powerful",
        "size": "14B",
        "context": "128K",
    },
}
def download_mlx_model(model_key: str, output_dir: Path) -> bool:
    """Fetch the pre-quantized MLX build of *model_key* into <output_dir>/<model_key>/mlx.

    Returns True on success, False when no MLX build exists or the
    download raises (e.g. missing huggingface_hub, network failure).
    """
    mlx_repo = MODELS[model_key].get("mlx_name")
    if not mlx_repo:
        print(f"❌ No MLX version available for {model_key}")
        return False
    print(f"\nπŸ“₯ Downloading {model_key} (MLX format)...")
    print(f" From: {mlx_repo}")
    target = output_dir / model_key / "mlx"
    target.mkdir(parents=True, exist_ok=True)
    try:
        # Imported lazily so the script still runs `list` without huggingface_hub.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=mlx_repo,
            local_dir=str(target),
            local_dir_use_symlinks=False,
        )
        print(f"βœ… Downloaded to: {target}")
        return True
    except Exception as err:
        print(f"❌ Download failed: {err}")
        return False
def download_hf_model(model_key: str, output_dir: Path) -> bool:
    """Mirror the full-precision HuggingFace repo into <output_dir>/<model_key>/hf.

    Returns True on success, False if the snapshot download raises
    (gated repos require a prior `huggingface-cli login`).
    """
    repo_id = MODELS[model_key]["hf_name"]
    print(f"\nπŸ“₯ Downloading {model_key} (HuggingFace format)...")
    print(f" From: {repo_id}")
    destination = output_dir / model_key / "hf"
    destination.mkdir(parents=True, exist_ok=True)
    try:
        # Imported lazily so the script still runs `list` without huggingface_hub.
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=repo_id,
            local_dir=str(destination),
            local_dir_use_symlinks=False,
            ignore_patterns=["*.bin", "*.h5"],  # Prefer safetensors
        )
    except Exception as err:
        print(f"❌ Download failed: {err}")
        print(" Note: Some models require HuggingFace login")
        print(" Run: huggingface-cli login")
        return False
    print(f"βœ… Downloaded to: {destination}")
    return True
def download_gguf_model(model_key: str, output_dir: Path, quant: str = "Q4_K_M") -> bool:
    """Download a GGUF quantized build of *model_key* for llama.cpp.

    Args:
        model_key: Key into MODELS.
        output_dir: Base directory; files land in <output_dir>/<model_key>/gguf.
        quant: Quantization tag that must appear in the filename (e.g. "Q4_K_M").

    Returns:
        True on success, False when the model has no GGUF repo or the
        download fails.
    """
    model = MODELS[model_key]
    gguf_name = model.get("gguf_name")
    if not gguf_name:
        print(f"❌ No GGUF version available for {model_key}")
        return False
    print(f"\nπŸ“₯ Downloading {model_key} (GGUF {quant} format)...")
    print(f" From: {gguf_name}")
    output_path = output_dir / model_key / "gguf"
    output_path.mkdir(parents=True, exist_ok=True)
    try:
        # Imported lazily so the script still runs `list` without huggingface_hub.
        from huggingface_hub import snapshot_download

        # BUG FIX: hf_hub_download() requires an exact filename and rejects
        # glob patterns, so the previous `filename=f"*{quant}*.gguf"` call
        # could never match a file.  snapshot_download() with allow_patterns
        # implements the intended "any file containing the quant tag" fetch.
        snapshot_download(
            repo_id=gguf_name,
            allow_patterns=[f"*{quant}*.gguf"],
            local_dir=str(output_path),
            local_dir_use_symlinks=False,
        )
        print(f"βœ… Downloaded to: {output_path}")
        return True
    except Exception as e:
        print(f"❌ Download failed: {e}")
        return False
def convert_to_mlx(model_path: Path, output_path: Path, quantize: bool = True) -> bool:
    """Convert a HuggingFace checkpoint to MLX format via `python -m mlx_lm.convert`.

    Args:
        model_path: Directory holding the HuggingFace-format model.
        output_path: Destination directory for the MLX conversion.
        quantize: When True (default), also 4-bit quantize during conversion.

    Returns:
        True if the converter subprocess exits cleanly, else False.
    """
    print(f"\nπŸ”„ Converting to MLX format...")
    command = [
        sys.executable,
        "-m",
        "mlx_lm.convert",
        "--hf-path",
        str(model_path),
        "--mlx-path",
        str(output_path),
    ]
    if quantize:
        command += ["--quantize", "--q-bits", "4"]
    try:
        subprocess.run(command, check=True)
        print(f"βœ… Converted to: {output_path}")
        return True
    except subprocess.CalledProcessError as err:
        print(f"❌ Conversion failed: {err}")
        return False
def verify_model(model_path: Path, backend: str = "mlx") -> bool:
    """Smoke-test that a downloaded model can actually be loaded.

    Args:
        model_path: Directory containing the model files.
        backend: "mlx" or "transformers"; any other value is reported as
            unsupported and fails verification.

    Returns:
        True if the model loads (and, for MLX, generates a short sample),
        else False.
    """
    print(f"\nπŸ” Verifying model at {model_path}...")
    if backend == "mlx":
        try:
            from mlx_lm import load, generate
            model, tokenizer = load(str(model_path))
            # Tiny generation to confirm weights + tokenizer are usable end to end.
            output = generate(model, tokenizer, "Hello", max_tokens=10)
            print(f"βœ… Model loaded successfully!")
            print(f" Test output: {output[:50]}...")
            return True
        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return False
    elif backend == "transformers":
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(str(model_path))
            model = AutoModelForCausalLM.from_pretrained(str(model_path))
            print(f"βœ… Model loaded successfully!")
            return True
        except Exception as e:
            print(f"❌ Verification failed: {e}")
            return False
    # BUG FIX: an unknown backend (e.g. "gguf" or "all", both reachable from
    # main()'s --format option) previously fell through and returned None with
    # no message; report it and fail explicitly.
    print(f"❌ Verification not supported for backend: {backend}")
    return False
def list_models() -> None:
    """Print a formatted table of every model this script knows about."""
    print("\nπŸ“‹ Available Models:\n")
    print(f"{'Model':<20} {'Size':<8} {'Context':<10} {'Description'}")
    print("-" * 80)
    for name, info in MODELS.items():
        row = f"{name:<20} {info['size']:<8} {info['context']:<10} {info['description']}"
        print(row)
def main() -> None:
    """CLI entry point: list, download, convert, or verify base models."""
    parser = argparse.ArgumentParser(description="Download and setup base models")
    parser.add_argument(
        "action",
        choices=["download", "convert", "verify", "list"],
        help="Action to perform",
    )
    parser.add_argument(
        "-m",
        "--model",
        choices=list(MODELS.keys()),
        default="llama-3.1-8b",
        help="Model to download",
    )
    parser.add_argument(
        "-f",
        "--format",
        choices=["mlx", "hf", "gguf", "all"],
        default="mlx",
        help="Model format",
    )
    parser.add_argument("-o", "--output", default="models/base", help="Output directory")
    parser.add_argument("-q", "--quant", default="Q4_K_M", help="GGUF quantization level")
    args = parser.parse_args()

    base_dir = Path(args.output)

    if args.action == "list":
        list_models()
        return

    if args.action == "download":
        # "all" triggers every format; otherwise only the requested one runs.
        if args.format in ("mlx", "all"):
            download_mlx_model(args.model, base_dir)
        if args.format in ("hf", "all"):
            download_hf_model(args.model, base_dir)
        if args.format in ("gguf", "all"):
            download_gguf_model(args.model, base_dir, args.quant)
        return

    if args.action == "convert":
        convert_to_mlx(
            base_dir / args.model / "hf",
            base_dir / args.model / "mlx-converted",
        )
        return

    if args.action == "verify":
        target = base_dir / args.model
        # The mlx/hf download variants live in a same-named subdirectory.
        if args.format in ("mlx", "hf"):
            target = target / args.format
        verify_model(target, args.format)


if __name__ == "__main__":
    main()