""" FastAPI Backend mit optimiertem Modell-Loading für HF Spaces Support für Quantization und Memory-Limited Environments """ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline ) import logging import time from pathlib import Path import os logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = FastAPI(title="Zephyr-7B API - Optimized") # CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Model Config - OPTIMIERT FÜR HF SPACES DEVICE = "cuda" if torch.cuda.is_available() else "cpu" USE_QUANTIZATION = True # 8-bit loading # Wähle Modell basierend auf verfügbarem Memory def select_model(): """Wählt das beste Modell für verfügbares Memory""" try: # GPU Memory check if torch.cuda.is_available(): gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 logger.info(f"GPU Memory: {gpu_memory:.1f}GB") # Wähle Modell basierend auf Memory if gpu_memory >= 20: # Genug für fp16 return "HuggingFaceH4/zephyr-7b-beta" elif gpu_memory >= 10: # AWQ 4-bit return "TheBloke/zephyr-7B-beta-AWQ" else: # GGUF 4-bit (kompressester) return "TheBloke/zephyr-7B-beta-GGUF" else: # CPU - nutze kleineres Modell return "Qwen/Qwen2.5-1.5B-Instruct" # "HuggingFaceH4/zephyr-7b-alpha" except Exception as e: logger.warning(f"Could not detect GPU memory: {e}, using safe default") return "Qwen/Qwen2.5-1.5B-Instruct"# "TheBloke/zephyr-7B-beta-AWQ" MODEL_NAME = os.getenv("MODEL_NAME", select_model()) logger.info(f"Using model: {MODEL_NAME}") # Initialize Model mit Quantization logger.info(f"Loading model {MODEL_NAME} on {DEVICE}...") def load_model_optimized(): """Lädt Modell mit optimaler Quantization für HF Spaces""" tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # Config für 8-bit Loading (spart 50% Memory!) if USE_QUANTIZATION and DEVICE == "cuda": try: bnb_config = BitsAndBytesConfig( load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16, bnb_8bit_use_double_quant=True, ) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, quantization_config=bnb_config, device_map="auto", ) logger.info("✓ Model loaded with 8-bit quantization") except Exception as e: logger.warning(f"8-bit quantization failed: {e}, trying default") model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, device_map="auto" if DEVICE == "cuda" else None, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, ) else: # Standard Loading für CPU oder non-quantized model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, device_map="auto" if DEVICE == "cuda" else None, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, ) return tokenizer, model try: tokenizer, model = load_model_optimized() pipe = pipeline( "text-generation", model=model, tokenizer=tokenizer, device=0 if DEVICE == "cuda" else -1, ) logger.info("✓ Pipeline initialized successfully") except Exception as e: logger.error(f"✗ Failed to load model: {e}") raise # Request Model class GenerateRequest(BaseModel): prompt: str system_prompt: str = None max_tokens: int = 512 temperature: float = 0.7 top_p: float = 0.9 @app.post("/api/generate") async def generate(request: GenerateRequest): """Generate text - optimized for HF Spaces""" try: start = time.time() # Format prompt if request.system_prompt: messages = f"<|system|>\n{request.system_prompt}\n<|user|>\n{request.prompt}\n<|assistant|>\n" else: messages = f"<|user|>\n{request.prompt}\n<|assistant|>\n" # Generate outputs = pipe( messages, max_new_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, do_sample=True, return_full_text=False, ) response_text = outputs[0]["generated_text"].strip() elapsed = time.time() - start return { "response": response_text, "tokens": len(tokenizer.encode(response_text)), "time_seconds": round(elapsed, 2), "model": MODEL_NAME, } except Exception as e: logger.error(f"Generation error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/api/health") async def health(): """Health check""" return { "status": "ok", "model": MODEL_NAME, "device": DEVICE, "quantization": USE_QUANTIZATION, } @app.get("/api/info") async def info(): """Model info""" gpu_memory = None if torch.cuda.is_available(): gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 return { "model": MODEL_NAME, "device": DEVICE, "gpu_memory_gb": gpu_memory, "quantization_enabled": USE_QUANTIZATION, } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)