AI_Chatbot / app_simple.py
LejobuildYT's picture
Update app_simple.py
e41af3e verified
#!/usr/bin/env python3
"""
Chat backend for Hugging Face Spaces.

Frontend and backend run in ONE container: FastAPI serves both the JSON API
(/api/*) and the static frontend (frontend.html), so no separate Vite
frontend build/dev server is needed.

Note: the app title still says Zephyr-7B, but the default model comes from
select_model() (Qwen/Qwen2.5-1.5B-Instruct). It can be overridden through the
MODEL_NAME environment variable.
"""
import logging
import os
import threading
import time
from pathlib import Path
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Zephyr-7B - HF Spaces")

# CORS: the bundled frontend is served same-origin, so this only affects
# third-party browser callers. Any origin may call the API, but credentials
# are disabled. None of the endpoints in this file read cookies or auth
# headers, and Starlette documents that credentialed CORS needs explicit
# (non-"*") origins, methods and headers anyway.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ========== MODEL LOADING ==========
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Request 8-bit loading (load_model_optimized only honours this on CUDA).
USE_QUANTIZATION = True
def select_model():
    """Return the default Hugging Face model id.

    No GPU-memory probing happens here. The small, fast Qwen2.5 1.5B
    instruct model is always chosen.
    """
    default_model_id = "Qwen/Qwen2.5-1.5B-Instruct"
    return default_model_id
# The MODEL_NAME environment variable, when set, overrides the built-in default.
MODEL_NAME = os.environ.get("MODEL_NAME", select_model())
logger.info(f"📌 Using model: {MODEL_NAME}")
def _load_unquantized():
    """Load MODEL_NAME without quantization.

    On CUDA the weights are loaded in float16 and placed with
    ``device_map="auto"``. On CPU they are loaded in float32 with no device map.
    """
    on_cuda = DEVICE == "cuda"
    return AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto" if on_cuda else None,
        torch_dtype=torch.float16 if on_cuda else torch.float32,
    )


def load_model_optimized():
    """Load the tokenizer and causal-LM model for MODEL_NAME.

    When USE_QUANTIZATION is set and DEVICE is "cuda", the model is first
    loaded in 8-bit via bitsandbytes. If that raises for any reason, a plain
    unquantized load is used instead. Without CUDA the model is always loaded
    unquantized.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if not (USE_QUANTIZATION and DEVICE == "cuda"):
        return tokenizer, _load_unquantized()

    try:
        # Only ``load_in_8bit`` applies to LLM.int8() loading. The previous
        # ``bnb_8bit_compute_dtype`` / ``bnb_8bit_use_double_quant`` keys are
        # not BitsAndBytesConfig parameters and were silently discarded.
        # Compute-dtype and double-quant settings exist only for 4-bit
        # (``bnb_4bit_*``).
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            # dtype for the modules that are not quantized
            torch_dtype=torch.float16,
        )
        logger.info("✅ Model loaded with 8-bit quantization")
    except Exception as e:
        # Best-effort: e.g. bitsandbytes missing or unsupported GPU.
        logger.warning(f"⚠️ 8-bit failed: {e}, trying standard")
        model = _load_unquantized()
    return tokenizer, model
try:
    logger.info(f"⏳ Loading {MODEL_NAME}...")
    tokenizer, model = load_model_optimized()

    # Every CUDA load path above uses device_map="auto", so accelerate has
    # already placed the model. transformers' pipeline raises ValueError when
    # ``device`` is passed for such a model. Pin a device only when the model
    # was not dispatched (the CPU path).
    _pipeline_kwargs = {}
    if getattr(model, "hf_device_map", None) is None:
        _pipeline_kwargs["device"] = 0 if DEVICE == "cuda" else -1

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        **_pipeline_kwargs,
    )
    logger.info("✅ Model ready!")
except Exception as e:
    # logger.exception records the traceback as well; the error still aborts startup.
    logger.exception(f"❌ Model loading failed: {e}")
    raise
# ========== API ENDPOINTS ==========
class GenerateRequest(BaseModel):
    """Request body for ``POST /api/generate``.

    Sampling is always enabled (``do_sample=True``), and transformers rejects
    a non-positive temperature and a top_p outside [0, 1]. Those values are
    therefore refused here with a 422 instead of failing inside generation
    with a 500.
    """

    prompt: str
    # Optional so that an explicit JSON null is accepted, not only an omitted field.
    system_prompt: Optional[str] = None
    # Upper bound on newly generated tokens (previously 512).
    max_tokens: int = Field(1024, ge=1)
    temperature: float = Field(0.7, gt=0)
    top_p: float = Field(0.9, ge=0, le=1)
# The blocking pipeline call used to run directly on the event loop, which
# made generation one-at-a-time. It now runs in a worker thread. This lock
# keeps the one-at-a-time behaviour, because the single shared model and
# tokenizer are not guaranteed to be safe for concurrent use across threads.
_GENERATION_LOCK = threading.Lock()


def _build_messages(request):
    """Return the chat messages: optional system turn, then the user turn."""
    messages = []
    if request.system_prompt:
        messages.append({"role": "system", "content": request.system_prompt})
    messages.append({"role": "user", "content": request.prompt})
    return messages


def _run_generation(messages, request):
    """Run the pipeline synchronously.

    Returns:
        ``(response_text, token_count, elapsed_seconds)``. The elapsed time
        excludes any wait for the lock.
    """
    with _GENERATION_LOCK:
        start = time.perf_counter()
        outputs = pipe(
            messages,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            do_sample=True,
            return_full_text=False,
        )
        response_text = outputs[0]["generated_text"].strip()
        # Count only the generated text, not BOS/EOS a tokenizer may add.
        token_count = len(tokenizer.encode(response_text, add_special_tokens=False))
        elapsed = time.perf_counter() - start
    return response_text, token_count, elapsed


@app.post("/api/generate")
async def generate(request: GenerateRequest):
    """Generate a chat completion for ``request``.

    Builds the chat messages (optional system prompt, then the user prompt)
    and runs the text-generation pipeline with sampling. The blocking model
    call runs in the threadpool, so other endpoints such as /api/health stay
    responsive during generation.

    Returns:
        A dict with ``response`` (stripped generated text), ``tokens`` (its
        token count), ``time_seconds`` (generation time, rounded to 2
        decimals) and ``model`` (MODEL_NAME).

    Raises:
        HTTPException: status 500 carrying the error message if generation fails.
    """
    try:
        messages = _build_messages(request)
        response_text, token_count, elapsed = await run_in_threadpool(
            _run_generation, messages, request
        )
    except Exception as e:
        logger.exception(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "response": response_text,
        "tokens": token_count,
        "time_seconds": round(elapsed, 2),
        "model": MODEL_NAME,
    }
@app.get("/api/health")
async def health():
    """Report that the service is up, along with the active model and device."""
    status = dict(status="ok", model=MODEL_NAME, device=DEVICE)
    return status
@app.get("/api/info")
async def info():
    """Describe the loaded model, the compute device and the quantization flag.

    ``gpu_memory_gb`` is the total memory of CUDA device 0 divided by 1e9,
    or ``None`` when CUDA is not available.
    """
    has_cuda = torch.cuda.is_available()
    total_gb = (
        torch.cuda.get_device_properties(0).total_memory / 1e9 if has_cuda else None
    )
    return dict(
        model=MODEL_NAME,
        device=DEVICE,
        gpu_memory_gb=total_gb,
        quantization=USE_QUANTIZATION,
    )
# ========== STATIC FILES & FRONTEND ==========
@app.get("/")
async def serve_frontend():
    """Return the single-page app shell, frontend.html, resolved against the working directory."""
    index_page = "frontend.html"
    return FileResponse(index_page, media_type="text/html")
def _resolve_static_file(full_path: str, root: Path) -> Optional[Path]:
    """Map a request path to a regular file inside ``root``, or return None.

    The following are rejected, so only ordinary files under ``root`` can be
    served:
    - absolute paths (e.g. from a ``//etc/passwd`` request);
    - ``..`` escapes, checked after symlink resolution;
    - any hidden (dot-prefixed) path component.
    """
    try:
        candidate = (root / full_path).resolve()
        relative = candidate.relative_to(root)
    except (ValueError, OSError, RuntimeError):
        # Outside root, invalid characters (e.g. NUL), or a symlink loop.
        return None
    if any(part.startswith(".") for part in relative.parts):
        return None
    return candidate if candidate.is_file() else None


@app.get("/{full_path:path}")
async def fallback(full_path: str):
    """Serve a static file from the working directory, else the SPA shell.

    Previously any path that existed was served, including absolute paths
    and directories (which made FileResponse fail). Now only regular,
    non-hidden files under the working directory are returned. Everything
    else gets frontend.html so client-side routing keeps working.
    """
    root = Path.cwd().resolve()
    file_path = _resolve_static_file(full_path, root)
    if file_path is not None:
        return FileResponse(file_path)
    # Otherwise serve the frontend (SPA routing)
    return FileResponse("frontend.html", media_type="text/html")
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly as a script.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")