from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoConfig
from peft import PeftModel
import torch
import logging
from pathlib import Path
import os
import platform
from .config import settings, apply_hf_space_optimizations
from .hf_api import HuggingFaceWrapper
from functools import lru_cache

logger = logging.getLogger(__name__)

def optimize_for_hf_space():
    """Apply optimizations specific to Hugging Face Spaces"""
    # Apply HF Space optimizations from config (includes cache dirs and other settings)
    apply_hf_space_optimizations()
    
    # Create cache directories
    cache_dirs = ["/tmp/transformers_cache", "/tmp/huggingface", "/tmp/torch"]
    for cache_dir in cache_dirs:
        Path(cache_dir).mkdir(parents=True, exist_ok=True)
    
    logger.info("🚀 Optimized cache directories for HF Space")

# Global variable for caching the pipeline
# _cached_generator_pipeline = None  # Removed, since we use lru_cache instead

def load_model_and_tokenizer():
    """
    Optimized model loader with LoRA support.
    Loads the base model and tokenizer.
    Can download a LoRA adapter from the Hugging Face Hub.
    Configures itself automatically based on the available resources.
    """
    # Apply HF Space optimizations
    optimize_for_hf_space()
    
    # Check if we're on macOS and disable 4-bit quantization if needed
    is_macos = platform.system() == "Darwin"
    if is_macos and settings.MODEL_LOAD_IN_4BIT:
        logger.warning("4-bit quantization is not recommended on macOS. Disabling 4-bit loading.")
        use_4bit = False
    else:
        use_4bit = settings.MODEL_LOAD_IN_4BIT
    
    # Only try to import bitsandbytes if we actually need 4-bit quantization
    if use_4bit:
        try:
            from transformers import BitsAndBytesConfig
            import bitsandbytes
            logger.info(f"Successfully imported bitsandbytes version: {bitsandbytes.__version__}")
            bitsandbytes_available = True
        except ImportError as e:
            logger.warning(f"Failed to import bitsandbytes: {e}. Disabling 4-bit quantization.")
            bitsandbytes_available = False
            use_4bit = False
        except Exception as e:
            # Catch other bitsandbytes related errors (like missing .dylib files)
            logger.warning(f"Bitsandbytes import failed with error: {e}. Disabling 4-bit quantization.")
            bitsandbytes_available = False
            use_4bit = False
    else:
        bitsandbytes_available = False
        if is_macos:
            logger.info("Running on macOS - using standard model loading without 4-bit quantization.")
        else:
            logger.info("4-bit quantization is disabled in settings.")

    base_model_id = settings.DEFAULT_MODEL_ID
    hf_token = os.getenv("HF_API_KEY")

    logger.info(f"Lade Basismodell und Tokenizer: {base_model_id}")
    
    try:
        # Try loading with fast tokenizer first
        logger.info("Versuche Fast-Tokenizer zu laden...")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_id, 
            token=hf_token,
            cache_dir="/tmp/transformers_cache"
        )
    except Exception as e:
        logger.warning(f"Fast-Tokenizer-Loading fehlgeschlagen: {e}")
        logger.info("Fallback auf Slow-Tokenizer...")
        try:
            # Fallback to slow tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                base_model_id, 
                token=hf_token,
                use_fast=False,
                cache_dir="/tmp/transformers_cache"
            )
            logger.info("✅ Slow-Tokenizer erfolgreich geladen.")
        except Exception as e2:
            logger.error(f"Auch Slow-Tokenizer fehlgeschlagen: {e2}")
            logger.error(f"Ursprünglicher Fast-Tokenizer Fehler: {e}")
            raise e2

    if tokenizer.pad_token is None:
        logger.info("Tokenizer hat kein pad_token. Setze pad_token = eos_token.")
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs = {
        "device_map": "auto",
        "trust_remote_code": True,
        "token": hf_token,
        "cache_dir": "/tmp/transformers_cache",  # Use optimized cache directory
        "low_cpu_mem_usage": True,  # Reduce CPU memory usage during loading
    }

    logger.info(f"DEBUG: use_4bit={use_4bit}, bitsandbytes_available={bitsandbytes_available}")
    
    if use_4bit and bitsandbytes_available:
        try:
            logger.info("Versuche, Modell mit 4-bit Quantisierung zu laden.")
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=False,
                bnb_4bit_compute_dtype=torch.float16
            )
            four_bit_model_kwargs = model_kwargs.copy()
            four_bit_model_kwargs["quantization_config"] = quantization_config

            model = AutoModelForCausalLM.from_pretrained(base_model_id, **four_bit_model_kwargs)
            logger.info("Modell erfolgreich mit 4-bit Quantisierung geladen.")
        except Exception as e:
            logger.warning(f"4-bit Laden fehlgeschlagen: {e}. Fallback auf Standard-Laden (FP16).")
            fallback_kwargs = model_kwargs.copy()
            fallback_kwargs["torch_dtype"] = torch.float16
            model = AutoModelForCausalLM.from_pretrained(base_model_id, **fallback_kwargs)
    else:
        logger.info("4-bit Quantisierung ist deaktiviert. Lade Modell in FP16.")

        # Prepare kwargs for AutoConfig and AutoModelForCausalLM
        shared_load_kwargs = {
            "token": hf_token,
            "trust_remote_code": True,
            "cache_dir": "/tmp/transformers_cache"
        }

        # Load config first
        try:
            config = AutoConfig.from_pretrained(base_model_id, **shared_load_kwargs)
            logger.info(f"Initial loaded config.parallelize_strategies: {getattr(config, 'parallelize_strategies', 'Not set')}")

            # More comprehensive approach to handle parallelize_strategies
            # Set it to an empty list if it's None or not set, as this seems to be safer
            if not hasattr(config, 'parallelize_strategies') or config.parallelize_strategies is None:
                config.parallelize_strategies = []
                logger.info("Set config.parallelize_strategies to empty list []")
            elif isinstance(config.parallelize_strategies, list):
                # Clean any None values from the list
                cleaned_strategies = [s for s in config.parallelize_strategies if s is not None]
                if len(cleaned_strategies) != len(config.parallelize_strategies):
                    config.parallelize_strategies = cleaned_strategies
                    logger.info(f"Cleaned config.parallelize_strategies to: {config.parallelize_strategies}")
            else:
                logger.warning(f"config.parallelize_strategies is not a list: {config.parallelize_strategies}. Setting to empty list.")
                config.parallelize_strategies = []
        except Exception as e:
            logger.error(f"Error loading or processing AutoConfig: {e}")
            # If config loading fails, proceed without a modified config; the original
            # error may then resurface when the model is loaded.
            config = None # Ensure model loading below doesn't fail on 'config' not defined

        # Prepare kwargs for AutoModelForCausalLM.from_pretrained
        final_fp16_model_kwargs = model_kwargs.copy()  # Starts from model_kwargs (device_map, trust_remote_code, token, cache_dir, low_cpu_mem_usage)
        final_fp16_model_kwargs["torch_dtype"] = torch.float16
        if config: # Only add config if it was successfully loaded and processed
            final_fp16_model_kwargs["config"] = config
        
        model = AutoModelForCausalLM.from_pretrained(base_model_id, **final_fp16_model_kwargs)

    # Load LoRA weights
    lora_path_to_load = None
    if settings.LORA_MODEL_REPO_ID:
        logger.info(f"LoRA Adapter soll von Hugging Face Hub geladen werden: {settings.LORA_MODEL_REPO_ID}")
        hf_wrapper = HuggingFaceWrapper(token=hf_token)  # Token wird intern vom Wrapper geholt, falls nicht explizit übergeben

        # Zielverzeichnis für heruntergeladene LoRA-Adapter
        # Basierend auf MODEL_PATH aus settings, um Konsistenz zu wahren
        # Beispiel: cardserver/models/lora-checkpoint/downloaded_adapters/your-lora-model-repo
        local_lora_download_dir_base = settings.resolved_model_path.parent / "downloaded_adapters"
        lora_adapter_name = settings.LORA_MODEL_REPO_ID.split("/")[-1]  # z.B. "your-lora-model-repo"
        local_lora_dir = local_lora_download_dir_base / lora_adapter_name

        # Check whether the adapter has already been downloaded (simple check).
        # A more robust check could use revision hashes or modification times.
        adapter_config_file = local_lora_dir / "adapter_config.json"

        if not adapter_config_file.exists() or getattr(settings, "LORA_FORCE_DOWNLOAD", False):
            if adapter_config_file.exists():
                logger.info(f"LORA_FORCE_DOWNLOAD ist aktiv. LoRA-Adapter wird erneut heruntergeladen: {settings.LORA_MODEL_REPO_ID}")
            else:
                logger.info(f"LoRA-Adapter nicht lokal gefunden unter {local_lora_dir}. Wird heruntergeladen...")

            local_lora_dir.mkdir(parents=True, exist_ok=True)  # Sicherstellen, dass das Verzeichnis existiert
            try:
                downloaded_path_str = hf_wrapper.download_model(
                    repo_name=settings.LORA_MODEL_REPO_ID,
                    local_dir=str(local_lora_dir),  # Must be a string
                    # revision=settings.LORA_MODEL_REVISION  # In case a specific revision is required
                )
                lora_path_to_load = Path(downloaded_path_str)  # The return value is the download path
                logger.info(f"LoRA adapter downloaded successfully from {settings.LORA_MODEL_REPO_ID} to {lora_path_to_load}.")
            except Exception as e:
                logger.error(f"Fehler beim Herunterladen des LoRA-Adapters von {settings.LORA_MODEL_REPO_ID}: {e}")
                logger.info("Versuche, Fallback auf lokalen Pfad (falls konfiguriert) oder verwende Basismodell.")
                # Fallback auf settings.resolved_model_path, falls LORA_MODEL_REPO_ID fehlschlägt
                if settings.resolved_model_path.exists() and (settings.resolved_model_path / "adapter_config.json").exists():
                    lora_path_to_load = settings.resolved_model_path
                    logger.info(f"Fallback auf lokalen LoRA-Pfad: {lora_path_to_load}")
                else:
                    lora_path_to_load = None  # Do not use LoRA
        else:
            lora_path_to_load = local_lora_dir
            logger.info(f"LoRA-Adapter {settings.LORA_MODEL_REPO_ID} bereits lokal vorhanden unter: {lora_path_to_load}")

    elif settings.resolved_model_path.exists() and (settings.resolved_model_path / "adapter_config.json").exists():
        # Fallback: LORA_MODEL_REPO_ID is not set, but a local path exists
        lora_path_to_load = settings.resolved_model_path
        logger.info(f"Using local LoRA path: {lora_path_to_load} (LORA_MODEL_REPO_ID is not set).")
    else:
        logger.info("Kein LORA_MODEL_REPO_ID in den Settings und kein gültiger lokaler LoRA-Pfad gefunden.")
        lora_path_to_load = None

    if lora_path_to_load:
        try:
            logger.info(f"Versuche, LoRA-Gewichte von Pfad zu laden: {lora_path_to_load}")
            model = PeftModel.from_pretrained(model, str(lora_path_to_load))
            logger.info("✅ LoRA-Modell erfolgreich auf Basismodell angewendet.")
        except Exception as e:
            logger.error(f"❌ LoRA-Loading von {lora_path_to_load} fehlgeschlagen: {e}")
            logger.info("Verwende Basismodell ohne LoRA-Adapter.")
    else:
        logger.info("Keine LoRA-Gewichte zum Laden spezifiziert oder gefunden. Verwende Basismodell.")

    return model, tokenizer


@lru_cache(maxsize=None)
def get_generator():
    """
    Loads the model and tokenizer (on the first call)
    and builds a text-generation pipeline.
    The pipeline is cached.
    """

    logger.info("Initialisiere Textgenerierungs-Pipeline...")
    model, tokenizer = load_model_and_tokenizer() 
    
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        logger.info(f"pad_token_id nicht im Tokenizer gefunden. Setze pad_token_id auf eos_token_id ({tokenizer.eos_token_id}).")
        tokenizer.pad_token_id = tokenizer.eos_token_id
        # Das Modell muss möglicherweise auch aktualisiert werden, wenn pad_token_id zur Laufzeit geändert wird
        # Dies ist jedoch oft nicht notwendig, wenn das Modell bereits mit einem eos_token trainiert wurde.
        # model.config.pad_token_id = tokenizer.pad_token_id

    # When using device_map="auto" with accelerate, don't specify device for pipeline
    # The pipeline will automatically use the same device mapping as the model
    _cached_generator_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
        # No device parameter when model uses device_map="auto"
    )
    logger.info(f"Textgenerierungs-Pipeline erfolgreich initialisiert. Model device mapping: {getattr(model, 'hf_device_map', 'No device map found')}")
    return _cached_generator_pipeline


def get_model_info():
    """Informationen über das geladene Modell"""
    lora_path = settings.resolved_model_path
    return {
        "base_model": settings.DEFAULT_MODEL_ID,
        "lora_enabled": lora_path.exists(),
        "lora_path": str(lora_path) if lora_path.exists() else None,
        "gpu_available": torch.cuda.is_available(),
        "gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
    }

# Optional: Pre-load model at startup if desired (in main.py or similar)
# def preload_model():
#     logger.info("Starte Pre-Loading des Modells...")
#     get_generator()
#     logger.info("Modell erfolgreich vorab geladen.")