#!/usr/bin/env python3
"""
LLM backend for Hugging Face Spaces.

Serves the frontend and the generation API from a single FastAPI app in ONE
container (no separate Vite dev server needed).

The default model is Qwen/Qwen2.5-1.5B-Instruct (see select_model); it can be
overridden with the MODEL_NAME environment variable. The "Zephyr-7B" in the
app title is presumably a leftover from an earlier model choice.
"""

import logging
import os
import threading
import time
from pathlib import Path
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles  # noqa: F401  (currently unused)
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Zephyr-7B - HF Spaces")

# CORS: accept requests from any origin.
# NOTE(review): a wildcard origin together with allow_credentials=True makes
# Starlette's CORSMiddleware echo the caller's Origin for credentialed
# requests. The API has no auth today, but revisit before adding cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ========== MODEL LOADING ==========

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Request 8-bit quantization (only attempted on CUDA).
USE_QUANTIZATION = True
# Set to True by load_model_optimized() only if the 8-bit load succeeded.
QUANTIZATION_ACTIVE = False

# Static files (and frontend.html) are served from the working directory.
# Requests are confined to this directory by _resolve_static_file().
STATIC_ROOT = Path.cwd().resolve()
FRONTEND_FILE = STATIC_ROOT / "frontend.html"


def select_model():
    """Return the default model id.

    Always returns the small, fast Qwen2.5 1.5B instruct model; no
    hardware-based selection is performed.
    """
    return "Qwen/Qwen2.5-1.5B-Instruct"


MODEL_NAME = os.getenv("MODEL_NAME", select_model())
logger.info(f"📌 Using model: {MODEL_NAME}")


def _load_standard_model():
    """Load the model unquantized: fp16 + device_map="auto" on CUDA, fp32 on CPU."""
    return AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto" if DEVICE == "cuda" else None,
        torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    )


def _load_8bit_model():
    """Load the model with bitsandbytes 8-bit weights (CUDA only)."""
    # BitsAndBytesConfig has no bnb_8bit_* options (compute dtype and double
    # quantization exist only for 4-bit, as bnb_4bit_*), so only the 8-bit
    # switch is passed.
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    return AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
    )


def load_model_optimized():
    """Load tokenizer and model, preferring 8-bit quantization on CUDA.

    If the 8-bit load raises, a warning is logged and the model is loaded
    unquantized instead. QUANTIZATION_ACTIVE is set to True only when the
    8-bit load succeeds.

    Returns:
        (tokenizer, model)
    """
    global QUANTIZATION_ACTIVE
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    if USE_QUANTIZATION and DEVICE == "cuda":
        try:
            model = _load_8bit_model()
        except Exception as e:  # e.g. bitsandbytes missing / unsupported GPU
            logger.warning(f"⚠️ 8-bit failed: {e}, trying standard")
        else:
            QUANTIZATION_ACTIVE = True
            logger.info("✅ Model loaded with 8-bit quantization")
            return tokenizer, model

    return tokenizer, _load_standard_model()


def _build_pipeline(tokenizer, model):
    """Wrap tokenizer + model in a text-generation pipeline.

    A model loaded with a device_map (accelerate) is already placed on its
    device(s), and transformers refuses an explicit `device` for such models.
    `device` is therefore passed only for models loaded without a device map.
    """
    kwargs = {}
    if getattr(model, "hf_device_map", None) is None:
        kwargs["device"] = 0 if DEVICE == "cuda" else -1
    return pipeline("text-generation", model=model, tokenizer=tokenizer, **kwargs)


try:
    logger.info(f"⏳ Loading {MODEL_NAME}...")
    tokenizer, model = load_model_optimized()
    pipe = _build_pipeline(tokenizer, model)
    logger.info("✅ Model ready!")
except Exception as e:
    logger.exception(f"❌ Model loading failed: {e}")
    raise


# ========== API ENDPOINTS ==========

class GenerateRequest(BaseModel):
    """JSON body of POST /api/generate."""

    prompt: str
    system_prompt: Optional[str] = None
    # Maximum number of newly generated tokens.
    max_tokens: int = Field(1024, ge=1)
    # Sampling is always enabled and transformers rejects a non-positive
    # temperature, so invalid values become a 422 here instead of a 500.
    temperature: float = Field(0.7, gt=0.0)
    top_p: float = Field(0.9, ge=0.0, le=1.0)


# Serializes access to the shared pipeline/tokenizer across worker threads
# (HF fast tokenizers can fail with "Already borrowed" under concurrent use).
_generation_lock = threading.Lock()


def _generate_blocking(messages, request):
    """Run one generation and return (response_text, generated_token_count).

    Blocking; meant to be called from a worker thread.
    """
    with _generation_lock:
        outputs = pipe(
            messages,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            do_sample=True,
            return_full_text=False,
        )
        # With return_full_text=False, generated_text holds only the new text.
        response_text = outputs[0]["generated_text"].strip()
        # Count only the generated tokens, not special tokens added by encode().
        n_tokens = len(tokenizer.encode(response_text, add_special_tokens=False))
    return response_text, n_tokens


@app.post("/api/generate")
async def generate(request: GenerateRequest):
    """Generate a chat reply to `prompt`, optionally with a system prompt.

    Returns the reply text, its token count, elapsed seconds and the model id.
    Any generation failure is reported as HTTP 500 with the error message.
    """
    try:
        start = time.perf_counter()

        # Chat-style messages; the pipeline applies the tokenizer's chat template.
        messages = []
        if request.system_prompt:
            messages.append({"role": "system", "content": request.system_prompt})
        messages.append({"role": "user", "content": request.prompt})

        # Generation is blocking CPU/GPU work: run it off the event loop so
        # other endpoints (e.g. /api/health) stay responsive meanwhile.
        response_text, n_tokens = await run_in_threadpool(
            _generate_blocking, messages, request
        )
        elapsed = time.perf_counter() - start

        return {
            "response": response_text,
            "tokens": n_tokens,
            "time_seconds": round(elapsed, 2),
            "model": MODEL_NAME,
        }
    except Exception as e:
        logger.exception(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/health")
async def health():
    """Health check: status, model id and device."""
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "device": DEVICE,
    }


@app.get("/api/info")
async def info():
    """Model info, including total memory (GB) of GPU 0 when CUDA is available."""
    gpu_memory = None
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    return {
        "model": MODEL_NAME,
        "device": DEVICE,
        "gpu_memory_gb": gpu_memory,
        # Whether 8-bit quantization is actually in effect, not merely requested.
        "quantization": QUANTIZATION_ACTIVE,
    }


# ========== STATIC FILES & FRONTEND ==========

def _resolve_static_file(full_path: str) -> Optional[Path]:
    """Map a request path to a regular file inside STATIC_ROOT, or None.

    Rejects absolute paths, `..` traversal and symlinks escaping STATIC_ROOT,
    hidden files/directories (e.g. `.env`, `.git`), and non-files.
    """
    try:
        candidate = (STATIC_ROOT / full_path).resolve()
        relative = candidate.relative_to(STATIC_ROOT)
    except (OSError, RuntimeError, ValueError):
        # Unresolvable path, embedded NUL byte, symlink loop, or outside root.
        return None
    if any(part.startswith(".") for part in relative.parts):
        return None
    return candidate if candidate.is_file() else None


@app.get("/")
async def serve_frontend():
    """Serve the single-page frontend."""
    return FileResponse(FRONTEND_FILE, media_type="text/html")


@app.get("/{full_path:path}")
async def fallback(full_path: str):
    """Serve a static file from STATIC_ROOT, otherwise the frontend (SPA routing).

    Unknown /api/... paths get a 404 rather than the HTML page so API
    clients see a real error.
    """
    if full_path == "api" or full_path.startswith("api/"):
        raise HTTPException(status_code=404, detail="Not Found")

    static_file = _resolve_static_file(full_path)
    if static_file is not None:
        return FileResponse(static_file)

    # Otherwise serve the frontend and let client-side routing handle the path.
    return FileResponse(FRONTEND_FILE, media_type="text/html")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)