# cardserver/app/core/model_loader.py
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoConfig
from peft import PeftModel
import torch
import logging
from pathlib import Path
import os
import platform
from .config import settings, apply_hf_space_optimizations
from .hf_api import HuggingFaceWrapper
from functools import lru_cache
logger = logging.getLogger(__name__)
def optimize_for_hf_space():
"""Apply optimizations specific to Hugging Face Spaces"""
# Apply HF Space optimizations from config (includes cache dirs and other settings)
apply_hf_space_optimizations()
# Create cache directories
cache_dirs = ["/tmp/transformers_cache", "/tmp/huggingface", "/tmp/torch"]
for cache_dir in cache_dirs:
Path(cache_dir).mkdir(parents=True, exist_ok=True)
logger.info("🚀 Optimized cache directories for HF Space")
# The generation pipeline is cached via lru_cache on get_generator();
# a manual module-level _cached_generator_pipeline global is no longer needed.
def load_model_and_tokenizer():
    """
    Optimized model loader with LoRA support.

    Loads the base model and tokenizer, can download LoRA adapters from the
    Hugging Face Hub, and configures itself automatically based on the
    available resources.
    """
# Apply HF Space optimizations
optimize_for_hf_space()
# Check if we're on macOS and disable 4-bit quantization if needed
is_macos = platform.system() == "Darwin"
if is_macos and settings.MODEL_LOAD_IN_4BIT:
logger.warning("4-bit quantization is not recommended on macOS. Disabling 4-bit loading.")
use_4bit = False
else:
use_4bit = settings.MODEL_LOAD_IN_4BIT
# Only try to import bitsandbytes if we actually need 4-bit quantization
if use_4bit:
try:
from transformers import BitsAndBytesConfig
import bitsandbytes
logger.info(f"Successfully imported bitsandbytes version: {bitsandbytes.__version__}")
bitsandbytes_available = True
except ImportError as e:
logger.warning(f"Failed to import bitsandbytes: {e}. Disabling 4-bit quantization.")
bitsandbytes_available = False
use_4bit = False
except Exception as e:
# Catch other bitsandbytes related errors (like missing .dylib files)
logger.warning(f"Bitsandbytes import failed with error: {e}. Disabling 4-bit quantization.")
bitsandbytes_available = False
use_4bit = False
else:
bitsandbytes_available = False
if is_macos:
logger.info("Running on macOS - using standard model loading without 4-bit quantization.")
else:
logger.info("4-bit quantization is disabled in settings.")
base_model_id = settings.DEFAULT_MODEL_ID
hf_token = os.getenv("HF_API_KEY")
logger.info(f"Lade Basismodell und Tokenizer: {base_model_id}")
try:
# Try loading with fast tokenizer first
logger.info("Versuche Fast-Tokenizer zu laden...")
tokenizer = AutoTokenizer.from_pretrained(
base_model_id,
token=hf_token,
cache_dir="/tmp/transformers_cache"
)
except Exception as e:
logger.warning(f"Fast-Tokenizer-Loading fehlgeschlagen: {e}")
logger.info("Fallback auf Slow-Tokenizer...")
try:
# Fallback to slow tokenizer
tokenizer = AutoTokenizer.from_pretrained(
base_model_id,
token=hf_token,
use_fast=False,
cache_dir="/tmp/transformers_cache"
)
logger.info("✅ Slow-Tokenizer erfolgreich geladen.")
except Exception as e2:
logger.error(f"Auch Slow-Tokenizer fehlgeschlagen: {e2}")
logger.error(f"Ursprünglicher Fast-Tokenizer Fehler: {e}")
raise e2
if tokenizer.pad_token is None:
logger.info("Tokenizer hat kein pad_token. Setze pad_token = eos_token.")
tokenizer.pad_token = tokenizer.eos_token
model_kwargs = {
"device_map": "auto",
"trust_remote_code": True,
"token": hf_token,
"cache_dir": "/tmp/transformers_cache", # Use optimized cache directory
"low_cpu_mem_usage": True, # Reduce CPU memory usage during loading
}
logger.info(f"DEBUG: use_4bit={use_4bit}, bitsandbytes_available={bitsandbytes_available}")
if use_4bit and bitsandbytes_available:
try:
logger.info("Versuche, Modell mit 4-bit Quantisierung zu laden.")
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=False,
bnb_4bit_compute_dtype=torch.float16
)
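            # Note on the settings above: nf4 ("NormalFloat4") generally
            # preserves quality better than plain int4 for normally
            # distributed weights; double quantization is disabled here,
            # and compute runs in fp16.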
four_bit_model_kwargs = model_kwargs.copy()
four_bit_model_kwargs["quantization_config"] = quantization_config
model = AutoModelForCausalLM.from_pretrained(base_model_id, **four_bit_model_kwargs)
logger.info("Modell erfolgreich mit 4-bit Quantisierung geladen.")
except Exception as e:
logger.warning(f"4-bit Laden fehlgeschlagen: {e}. Fallback auf Standard-Laden (FP16).")
fallback_kwargs = model_kwargs.copy()
fallback_kwargs["torch_dtype"] = torch.float16
model = AutoModelForCausalLM.from_pretrained(base_model_id, **fallback_kwargs)
else:
logger.info("4-bit Quantisierung ist deaktiviert. Lade Modell in FP16.")
# Prepare kwargs for AutoConfig and AutoModelForCausalLM
shared_load_kwargs = {
"token": hf_token,
"trust_remote_code": True,
"cache_dir": "/tmp/transformers_cache"
}
# Load config first
try:
config = AutoConfig.from_pretrained(base_model_id, **shared_load_kwargs)
logger.info(f"Initial loaded config.parallelize_strategies: {getattr(config, 'parallelize_strategies', 'Not set')}")
            # Normalize parallelize_strategies: some configs ship it as None,
            # which can break model loading, so default it to an empty list.
if not hasattr(config, 'parallelize_strategies') or config.parallelize_strategies is None:
config.parallelize_strategies = []
logger.info("Set config.parallelize_strategies to empty list []")
elif isinstance(config.parallelize_strategies, list):
# Clean any None values from the list
cleaned_strategies = [s for s in config.parallelize_strategies if s is not None]
if len(cleaned_strategies) != len(config.parallelize_strategies):
config.parallelize_strategies = cleaned_strategies
logger.info(f"Cleaned config.parallelize_strategies to: {config.parallelize_strategies}")
else:
logger.warning(f"config.parallelize_strategies is not a list: {config.parallelize_strategies}. Setting to empty list.")
config.parallelize_strategies = []
except Exception as e:
logger.error(f"Error loading or processing AutoConfig: {e}")
            # If config loading fails, proceed without a custom config; the
            # original error may resurface, but loading can still succeed.
config = None # Ensure model loading below doesn't fail on 'config' not defined
# Prepare kwargs for AutoModelForCausalLM.from_pretrained
        final_fp16_model_kwargs = model_kwargs.copy()  # device_map, trust_remote_code, token, cache_dir, low_cpu_mem_usage
final_fp16_model_kwargs["torch_dtype"] = torch.float16
if config: # Only add config if it was successfully loaded and processed
final_fp16_model_kwargs["config"] = config
model = AutoModelForCausalLM.from_pretrained(base_model_id, **final_fp16_model_kwargs)
    # Load LoRA weights
lora_path_to_load = None
if settings.LORA_MODEL_REPO_ID:
logger.info(f"LoRA Adapter soll von Hugging Face Hub geladen werden: {settings.LORA_MODEL_REPO_ID}")
hf_wrapper = HuggingFaceWrapper(token=hf_token) # Token wird intern vom Wrapper geholt, falls nicht explizit übergeben
        # Target directory for downloaded LoRA adapters, derived from MODEL_PATH
        # in settings to keep paths consistent.
        # Example: cardserver/models/lora-checkpoint/downloaded_adapters/your-lora-model-repo
local_lora_download_dir_base = settings.resolved_model_path.parent / "downloaded_adapters"
        lora_adapter_name = settings.LORA_MODEL_REPO_ID.split("/")[-1]  # e.g. "your-lora-model-repo"
local_lora_dir = local_lora_download_dir_base / lora_adapter_name
        # Check whether the adapter has already been downloaded (simple check).
        # A more robust check could compare revision hashes or modification times.
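        # Sketch of such a check (an assumption, not wired in here): compare
        # the remote commit hash against a locally stored one via
        # huggingface_hub.
        #
        #     from huggingface_hub import HfApi
        #     remote_sha = HfApi().model_info(settings.LORA_MODEL_REPO_ID).sha
        #     # re-download only when remote_sha differs from the stored sha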
adapter_config_file = local_lora_dir / "adapter_config.json"
if not adapter_config_file.exists() or getattr(settings, "LORA_FORCE_DOWNLOAD", False):
if adapter_config_file.exists():
logger.info(f"LORA_FORCE_DOWNLOAD ist aktiv. LoRA-Adapter wird erneut heruntergeladen: {settings.LORA_MODEL_REPO_ID}")
else:
logger.info(f"LoRA-Adapter nicht lokal gefunden unter {local_lora_dir}. Wird heruntergeladen...")
local_lora_dir.mkdir(parents=True, exist_ok=True) # Sicherstellen, dass das Verzeichnis existiert
try:
downloaded_path_str = hf_wrapper.download_model(
repo_name=settings.LORA_MODEL_REPO_ID,
                    local_dir=str(local_lora_dir),  # must be a string
                    # revision=settings.LORA_MODEL_REVISION  # if a specific revision is needed
                )
                lora_path_to_load = Path(downloaded_path_str)  # the return value is the download path
                logger.info(f"LoRA adapter successfully downloaded from {settings.LORA_MODEL_REPO_ID} to {lora_path_to_load}.")
except Exception as e:
logger.error(f"Fehler beim Herunterladen des LoRA-Adapters von {settings.LORA_MODEL_REPO_ID}: {e}")
logger.info("Versuche, Fallback auf lokalen Pfad (falls konfiguriert) oder verwende Basismodell.")
# Fallback auf settings.resolved_model_path, falls LORA_MODEL_REPO_ID fehlschlägt
if settings.resolved_model_path.exists() and (settings.resolved_model_path / "adapter_config.json").exists():
lora_path_to_load = settings.resolved_model_path
logger.info(f"Fallback auf lokalen LoRA-Pfad: {lora_path_to_load}")
else:
lora_path_to_load = None # Kein LoRA verwenden
else:
lora_path_to_load = local_lora_dir
logger.info(f"LoRA-Adapter {settings.LORA_MODEL_REPO_ID} bereits lokal vorhanden unter: {lora_path_to_load}")
elif settings.resolved_model_path.exists() and (settings.resolved_model_path / "adapter_config.json").exists():
        # Fallback: LORA_MODEL_REPO_ID is not set, but a local adapter path exists
        lora_path_to_load = settings.resolved_model_path
        logger.info(f"Using local LoRA path: {lora_path_to_load} (LORA_MODEL_REPO_ID not set).")
    else:
        logger.info("No LORA_MODEL_REPO_ID in settings and no valid local LoRA path found.")
lora_path_to_load = None
if lora_path_to_load:
try:
logger.info(f"Versuche, LoRA-Gewichte von Pfad zu laden: {lora_path_to_load}")
model = PeftModel.from_pretrained(model, str(lora_path_to_load))
logger.info("✅ LoRA-Modell erfolgreich auf Basismodell angewendet.")
except Exception as e:
logger.error(f"❌ LoRA-Loading von {lora_path_to_load} fehlgeschlagen: {e}")
logger.info("Verwende Basismodell ohne LoRA-Adapter.")
else:
logger.info("Keine LoRA-Gewichte zum Laden spezifiziert oder gefunden. Verwende Basismodell.")
return model, tokenizer
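# Illustrative usage sketch (not executed at import time): calling
# load_model_and_tokenizer() directly is useful for custom generation
# without the pipeline wrapper.
#
#     model, tokenizer = load_model_and_tokenizer()
#     inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
#     output_ids = model.generate(**inputs, max_new_tokens=32)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))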
@lru_cache(maxsize=None)
def get_generator():
"""
Lädt das Modell und den Tokenizer (beim ersten Aufruf)
und erstellt eine Textgenerierungs-Pipeline.
Die Pipeline wird gecacht.
"""
logger.info("Initialisiere Textgenerierungs-Pipeline...")
model, tokenizer = load_model_and_tokenizer()
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
logger.info(f"pad_token_id nicht im Tokenizer gefunden. Setze pad_token_id auf eos_token_id ({tokenizer.eos_token_id}).")
tokenizer.pad_token_id = tokenizer.eos_token_id
# Das Modell muss möglicherweise auch aktualisiert werden, wenn pad_token_id zur Laufzeit geändert wird
# Dies ist jedoch oft nicht notwendig, wenn das Modell bereits mit einem eos_token trainiert wurde.
# model.config.pad_token_id = tokenizer.pad_token_id
# When using device_map="auto" with accelerate, don't specify device for pipeline
# The pipeline will automatically use the same device mapping as the model
    generator_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
        # No device parameter when the model uses device_map="auto"
    )
    logger.info(f"Text generation pipeline initialized successfully. Model device mapping: {getattr(model, 'hf_device_map', 'No device map found')}")
    return generator_pipeline
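# Illustrative usage (standard transformers generation kwargs):
#
#     generator = get_generator()
#     result = generator("Once upon a time", max_new_tokens=50, do_sample=True)
#     print(result[0]["generated_text"])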
def get_model_info():
"""Informationen über das geladene Modell"""
lora_path = settings.resolved_model_path
return {
"base_model": settings.DEFAULT_MODEL_ID,
"lora_enabled": lora_path.exists(),
"lora_path": str(lora_path) if lora_path.exists() else None,
"gpu_available": torch.cuda.is_available(),
"gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
}
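# Illustrative return value on a CPU-only host without a local adapter
# (values are hypothetical and depend on settings and hardware):
#
#     {
#         "base_model": "mistralai/Mistral-7B-Instruct-v0.2",  # hypothetical id
#         "lora_enabled": False,
#         "lora_path": None,
#         "gpu_available": False,
#         "gpu_count": 0,
#     }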
# Optional: pre-load the model at startup if desired (in main.py or similar)
# def preload_model():
#     logger.info("Starting model pre-loading...")
#     get_generator()
#     logger.info("Model pre-loaded successfully.")