# AI_Chatbot / app_fastapi.py
# Last commit by LejobuildYT: "Update app_fastapi.py" (84ed8f8, verified)
"""
FastAPI Backend mit optimiertem Modell-Loading für HF Spaces
Support für Quantization und Memory-Limited Environments
"""
import logging
import os
import threading
import time
from pathlib import Path
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Zephyr-7B API - Optimized")

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; consider listing the real front-end origin(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model config - tuned for HF Spaces
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 8-bit loading via bitsandbytes (the loader only applies it on CUDA).
# Enabled by default; set the USE_QUANTIZATION env var to 0/false/no/off to
# disable it without editing the code.
USE_QUANTIZATION = os.getenv("USE_QUANTIZATION", "1").strip().lower() not in {
    "0",
    "false",
    "no",
    "off",
}
# Small model used on CPU-only hosts, on GPUs too small for a 7B model, and
# whenever GPU probing fails.
CPU_FALLBACK_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"


# Pick a model based on the available memory
def select_model(gpu_memory_gb=None):
    """Pick the largest chat model that fits the available GPU memory.

    Args:
        gpu_memory_gb: Total GPU memory in GB. When None (the default), it is
            read from CUDA device 0; pass a number to skip hardware detection.

    Returns:
        A Hugging Face Hub model id:
        "HuggingFaceH4/zephyr-7b-beta" for >= 20 GB,
        "TheBloke/zephyr-7B-beta-AWQ" for >= 10 GB, and
        CPU_FALLBACK_MODEL otherwise. CPU_FALLBACK_MODEL is also returned when
        no CUDA device is available or detection fails.
    """
    if gpu_memory_gb is None:
        try:
            if not torch.cuda.is_available():
                # CPU only - use the small model
                return CPU_FALLBACK_MODEL
            gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            logger.info(f"GPU Memory: {gpu_memory_gb:.1f}GB")
        except Exception as e:
            # Best effort: any probing failure degrades to the safe default.
            logger.warning(f"Could not detect GPU memory: {e}, using safe default")
            return CPU_FALLBACK_MODEL

    if gpu_memory_gb >= 20:  # enough for fp16 weights
        return "HuggingFaceH4/zephyr-7b-beta"
    if gpu_memory_gb >= 10:  # pre-quantized AWQ 4-bit checkpoint
        return "TheBloke/zephyr-7B-beta-AWQ"
    # GGUF-only repos (e.g. TheBloke/zephyr-7B-beta-GGUF) cannot be loaded by
    # plain AutoTokenizer / AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    # calls, so small GPUs get the small transformers-format model instead.
    return CPU_FALLBACK_MODEL
# An explicit, non-empty MODEL_NAME env var wins. The hardware probe (and its
# log output) only runs when no override is given, instead of always being
# evaluated as the getenv default.
MODEL_NAME = os.getenv("MODEL_NAME") or select_model()
logger.info(f"Using model: {MODEL_NAME}")
# Initialize the model (with quantization where enabled)
logger.info(f"Loading model {MODEL_NAME} on {DEVICE}...")
def _load_unquantized():
    """Load MODEL_NAME without bitsandbytes: fp16 + device_map="auto" on CUDA, fp32 on CPU."""
    on_cuda = DEVICE == "cuda"
    return AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto" if on_cuda else None,
        torch_dtype=torch.float16 if on_cuda else torch.float32,
    )


def load_model_optimized():
    """Load the tokenizer and model for MODEL_NAME, using 8-bit weights when possible.

    On CUDA with USE_QUANTIZATION enabled, the model is loaded with a
    bitsandbytes 8-bit config. If that fails, the failure is logged and an
    unquantized fp16 model is loaded instead. All other configurations load
    unquantized directly.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    if not (USE_QUANTIZATION and DEVICE == "cuda"):
        # Standard loading for CPU or when quantization is disabled
        return tokenizer, _load_unquantized()

    # NOTE(review): if MODEL_NAME is already a pre-quantized checkpoint (e.g.
    # an AWQ repo), confirm how transformers reconciles it with this config.
    try:
        # 8-bit weights roughly halve memory compared with fp16. Compute-dtype
        # and double-quant options exist only for 4-bit (bnb_4bit_*), so none
        # are set here.
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
        )
    except Exception as e:
        logger.warning(f"8-bit quantization failed: {e}, trying default")
        return tokenizer, _load_unquantized()

    logger.info("✓ Model loaded with 8-bit quantization")
    return tokenizer, model
try:
    tokenizer, model = load_model_optimized()
    # A model loaded with device_map (dispatched by accelerate) already sits on
    # its device(s) and carries `hf_device_map`. transformers' pipeline refuses
    # an explicit `device` for such models, so only pass one otherwise.
    pipeline_kwargs = {}
    if getattr(model, "hf_device_map", None) is None:
        pipeline_kwargs["device"] = 0 if DEVICE == "cuda" else -1
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        **pipeline_kwargs,
    )
    logger.info("✓ Pipeline initialized successfully")
except Exception as e:
    # logger.exception keeps the traceback, which matters for startup failures.
    logger.exception(f"✗ Failed to load model: {e}")
    raise
# Request Model
class GenerateRequest(BaseModel):
    """JSON body accepted by POST /api/generate.

    The bounds reject values the sampling pipeline cannot use (generation
    always samples, do_sample=True). Such requests get a 422 validation
    response instead of a 500 raised from inside the pipeline.
    """

    # User message to answer.
    prompt: str
    # Optional system message; None/omitted means no system turn.
    system_prompt: Optional[str] = None
    # Maximum number of newly generated tokens.
    max_tokens: int = Field(512, ge=1)
    # Sampling temperature; sampling requires a strictly positive value.
    temperature: float = Field(0.7, gt=0)
    # Nucleus-sampling probability mass, within [0, 1].
    top_p: float = Field(0.9, ge=0, le=1)
# Serialises use of the shared pipeline. Generation runs in a worker thread so
# the event loop stays responsive, but only one request generates at a time.
_GENERATE_LOCK = threading.Lock()


def _build_prompt(request):
    """Render the request as the prompt string passed to the pipeline.

    Uses the tokenizer's own chat template when it has one, so the prompt
    matches whichever model was loaded. The hand-written Zephyr markup below
    does not fit other chat formats, such as the Qwen model select_model() may
    pick. Falls back to that markup for tokenizers without a template.
    """
    if getattr(tokenizer, "chat_template", None):
        chat = []
        if request.system_prompt:
            chat.append({"role": "system", "content": request.system_prompt})
        chat.append({"role": "user", "content": request.prompt})
        return tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
    if request.system_prompt:
        return f"<|system|>\n{request.system_prompt}\n<|user|>\n{request.prompt}\n<|assistant|>\n"
    return f"<|user|>\n{request.prompt}\n<|assistant|>\n"


def _run_generation(prompt_text, request):
    """Run the blocking pipeline call while holding the generation lock."""
    with _GENERATE_LOCK:
        return pipe(
            prompt_text,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            do_sample=True,
            return_full_text=False,
        )


@app.post("/api/generate")
async def generate(request: GenerateRequest):
    """Generate a completion for ``request`` with the loaded model.

    Returns:
        A dict with the generated ``response`` text, its length in ``tokens``
        (special tokens excluded), wall-clock ``time_seconds`` and the
        ``model`` id.

    Raises:
        HTTPException: status 500 carrying the error message if prompt
            building or generation fails.
    """
    try:
        start = time.perf_counter()
        prompt_text = _build_prompt(request)
        # The pipeline call is CPU/GPU-bound and takes seconds. Calling it
        # directly in this coroutine would block every other request,
        # including /api/health.
        outputs = await run_in_threadpool(_run_generation, prompt_text, request)
        response_text = outputs[0]["generated_text"].strip()
        elapsed = time.perf_counter() - start
        return {
            "response": response_text,
            "tokens": len(tokenizer.encode(response_text, add_special_tokens=False)),
            "time_seconds": round(elapsed, 2),
            "model": MODEL_NAME,
        }
    except Exception as e:
        logger.exception(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/health")
async def health():
    """Liveness probe.

    Always reports status "ok" together with the configured model id, the
    device string and the USE_QUANTIZATION flag.
    """
    status = dict(
        status="ok",
        model=MODEL_NAME,
        device=DEVICE,
        quantization=USE_QUANTIZATION,
    )
    return status
@app.get("/api/info")
async def info():
    """Describe the configured model and the hardware it runs on.

    ``gpu_memory_gb`` is the total memory of CUDA device 0 in GB, or None when
    CUDA is unavailable. ``quantization_enabled`` echoes the USE_QUANTIZATION
    setting; it is not re-checked against how the model was actually loaded.
    """
    cuda_available = torch.cuda.is_available()
    total_gb = (
        torch.cuda.get_device_properties(0).total_memory / 1e9
        if cuda_available
        else None
    )
    return dict(
        model=MODEL_NAME,
        device=DEVICE,
        gpu_memory_gb=total_gb,
        quantization_enabled=USE_QUANTIZATION,
    )
if __name__ == "__main__":
    import uvicorn

    # Defaults match HF Spaces (all interfaces, port 7860). The HOST and PORT
    # env vars allow overriding them when running elsewhere.
    uvicorn.run(
        app,
        host=os.getenv("HOST", "0.0.0.0"),
        port=int(os.getenv("PORT", "7860")),
    )