#!/usr/bin/env python3
"""
Model Download and Setup for FinEE v2.0
========================================

Downloads and prepares base models for fine-tuning:
- Llama 3.1 8B Instruct (Primary)
- Qwen2.5 7B Instruct (Backup)

Supports:
- MLX format for Apple Silicon
- PyTorch/Transformers format
- GGUF for llama.cpp
"""

import argparse
import subprocess
import sys
from pathlib import Path

# Registry of supported base models.
# "mlx_name" / "gguf_name" are optional: a missing key means that format
# is not published for the model (the download helpers report this and
# return False rather than raising).
MODELS = {
    "llama-3.1-8b": {
        "hf_name": "meta-llama/Llama-3.1-8B-Instruct",
        "mlx_name": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        "gguf_name": "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
        "description": "Llama 3.1 8B Instruct - Best instruction-following",
        "size": "8B",
        "context": "128K",
    },
    "qwen2.5-7b": {
        "hf_name": "Qwen/Qwen2.5-7B-Instruct",
        "mlx_name": "mlx-community/Qwen2.5-7B-Instruct-4bit",
        "gguf_name": "Qwen/Qwen2.5-7B-Instruct-GGUF",
        "description": "Qwen 2.5 7B - Excellent multilingual support",
        "size": "7B",
        "context": "128K",
    },
    "mistral-7b": {
        "hf_name": "mistralai/Mistral-7B-Instruct-v0.3",
        "mlx_name": "mlx-community/Mistral-7B-Instruct-v0.3-4bit",
        "gguf_name": "bartowski/Mistral-7B-Instruct-v0.3-GGUF",
        "description": "Mistral 7B - Fast and efficient",
        "size": "7B",
        "context": "32K",
    },
    "phi-3-medium": {
        "hf_name": "microsoft/Phi-3-medium-128k-instruct",
        "mlx_name": "mlx-community/Phi-3-medium-128k-instruct-4bit",
        # No GGUF repo configured for this model.
        "description": "Phi-3 Medium - Compact but powerful",
        "size": "14B",
        "context": "128K",
    },
}


def download_mlx_model(model_key: str, output_dir: Path) -> bool:
    """Download a model in MLX (4-bit) format.

    Args:
        model_key: Key into MODELS.
        output_dir: Base output directory; files land in
            ``output_dir / model_key / "mlx"``.

    Returns:
        True on success, False if no MLX repo exists or download fails.
    """
    model = MODELS[model_key]
    mlx_name = model.get("mlx_name")

    if not mlx_name:
        print(f"āŒ No MLX version available for {model_key}")
        return False

    print(f"\nšŸ“„ Downloading {model_key} (MLX format)...")
    print(f"   From: {mlx_name}")

    output_path = output_dir / model_key / "mlx"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # Imported lazily so the script works (list/verify-hf paths)
        # without huggingface_hub installed.
        from huggingface_hub import snapshot_download

        # NOTE: local_dir_use_symlinks is deprecated in huggingface_hub;
        # passing local_dir alone now materializes real files.
        snapshot_download(
            repo_id=mlx_name,
            local_dir=str(output_path),
        )
        print(f"āœ… Downloaded to: {output_path}")
        return True
    except Exception as e:
        print(f"āŒ Download failed: {e}")
        return False


def download_hf_model(model_key: str, output_dir: Path) -> bool:
    """Download a model in HuggingFace (safetensors) format.

    Returns True on success, False on failure. Gated repos (e.g. Llama)
    require a prior ``huggingface-cli login``.
    """
    model = MODELS[model_key]
    hf_name = model["hf_name"]

    print(f"\nšŸ“„ Downloading {model_key} (HuggingFace format)...")
    print(f"   From: {hf_name}")

    output_path = output_dir / model_key / "hf"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id=hf_name,
            local_dir=str(output_path),
            ignore_patterns=["*.bin", "*.h5"],  # Prefer safetensors
        )
        print(f"āœ… Downloaded to: {output_path}")
        return True
    except Exception as e:
        print(f"āŒ Download failed: {e}")
        print("   Note: Some models require HuggingFace login")
        print("   Run: huggingface-cli login")
        return False


def download_gguf_model(model_key: str, output_dir: Path, quant: str = "Q4_K_M") -> bool:
    """Download a GGUF quantized model.

    Args:
        model_key: Key into MODELS.
        output_dir: Base output directory; files land in
            ``output_dir / model_key / "gguf"``.
        quant: Quantization tag to match in the filename (e.g. "Q4_K_M").

    Returns:
        True on success, False if no GGUF repo exists or download fails.
    """
    model = MODELS[model_key]
    gguf_name = model.get("gguf_name")

    if not gguf_name:
        print(f"āŒ No GGUF version available for {model_key}")
        return False

    print(f"\nšŸ“„ Downloading {model_key} (GGUF {quant} format)...")
    print(f"   From: {gguf_name}")

    output_path = output_dir / model_key / "gguf"
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        from huggingface_hub import snapshot_download

        # BUGFIX: hf_hub_download() requires an exact filename and does
        # not expand wildcards, so the old f"*{quant}*.gguf" filename
        # always 404'd. snapshot_download with allow_patterns performs
        # the glob match server-side listing and fetches only the
        # matching quantization file(s).
        snapshot_download(
            repo_id=gguf_name,
            local_dir=str(output_path),
            allow_patterns=[f"*{quant}*.gguf"],
        )
        print(f"āœ… Downloaded to: {output_path}")
        return True
    except Exception as e:
        print(f"āŒ Download failed: {e}")
        return False


def convert_to_mlx(model_path: Path, output_path: Path, quantize: bool = True) -> bool:
    """Convert a HuggingFace model directory to MLX format.

    Shells out to ``python -m mlx_lm.convert``; with quantize=True the
    result is 4-bit quantized.

    Returns True on success, False if the converter exits non-zero.
    """
    print("\nšŸ”„ Converting to MLX format...")

    cmd = [
        sys.executable,
        "-m",
        "mlx_lm.convert",
        "--hf-path",
        str(model_path),
        "--mlx-path",
        str(output_path),
    ]

    if quantize:
        cmd.extend(["--quantize", "--q-bits", "4"])

    try:
        # List-form argv, shell=False: no shell-injection risk from paths.
        subprocess.run(cmd, check=True)
        print(f"āœ… Converted to: {output_path}")
        return True
    except subprocess.CalledProcessError as e:
        print(f"āŒ Conversion failed: {e}")
        return False


def verify_model(model_path: Path, backend: str = "mlx") -> bool:
    """Verify a downloaded model can be loaded (and, for MLX, generate).

    Args:
        model_path: Directory containing the model files.
        backend: "mlx" or "transformers".

    Returns:
        True if the model loads, False on load failure or an
        unrecognized backend.
    """
    print(f"\nšŸ” Verifying model at {model_path}...")

    if backend == "mlx":
        try:
            from mlx_lm import load, generate

            model, tokenizer = load(str(model_path))
            # Quick smoke test: a tiny generation proves weights+tokenizer work.
            output = generate(model, tokenizer, "Hello", max_tokens=10)
            print(f"āœ… Model loaded successfully!")
            print(f"   Test output: {output[:50]}...")
            return True
        except Exception as e:
            print(f"āŒ Verification failed: {e}")
            return False
    elif backend == "transformers":
        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(str(model_path))
            model = AutoModelForCausalLM.from_pretrained(str(model_path))
            print(f"āœ… Model loaded successfully!")
            return True
        except Exception as e:
            print(f"āŒ Verification failed: {e}")
            return False
    # Explicit failure for unknown backends (previously fell through to None).
    print(f"āŒ Unknown backend: {backend}")
    return False


def list_models() -> None:
    """Print a table of the models available in MODELS."""
    print("\nšŸ“‹ Available Models:\n")
    print(f"{'Model':<20} {'Size':<8} {'Context':<10} {'Description'}")
    print("-" * 80)
    for key, model in MODELS.items():
        print(f"{key:<20} {model['size']:<8} {model['context']:<10} {model['description']}")


def main():
    """CLI entry point: download / convert / verify / list base models."""
    parser = argparse.ArgumentParser(description="Download and setup base models")
    parser.add_argument("action", choices=["download", "convert", "verify", "list"],
                        help="Action to perform")
    parser.add_argument("-m", "--model", choices=list(MODELS.keys()),
                        default="llama-3.1-8b", help="Model to download")
    parser.add_argument("-f", "--format", choices=["mlx", "hf", "gguf", "all"],
                        default="mlx", help="Model format")
    parser.add_argument("-o", "--output", default="models/base",
                        help="Output directory")
    parser.add_argument("-q", "--quant", default="Q4_K_M",
                        help="GGUF quantization level")

    args = parser.parse_args()
    output_dir = Path(args.output)

    if args.action == "list":
        list_models()
        return

    if args.action == "download":
        # "all" fans out to every format; each download is independent.
        if args.format in ["mlx", "all"]:
            download_mlx_model(args.model, output_dir)
        if args.format in ["hf", "all"]:
            download_hf_model(args.model, output_dir)
        if args.format in ["gguf", "all"]:
            download_gguf_model(args.model, output_dir, args.quant)

    elif args.action == "convert":
        hf_path = output_dir / args.model / "hf"
        mlx_path = output_dir / args.model / "mlx-converted"
        convert_to_mlx(hf_path, mlx_path)

    elif args.action == "verify":
        model_path = output_dir / args.model
        if args.format == "mlx":
            model_path = model_path / "mlx"
        elif args.format == "hf":
            model_path = model_path / "hf"
        verify_model(model_path, args.format)


if __name__ == "__main__":
    main()