api-mg / app.py
MGZON's picture
Update app.py
b3daae1 verified
import asyncio
import logging
import os
import threading
import time

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from llama_cpp import Llama
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Set up logging: configure the root logger once at import time (INFO level,
# timestamp - logger name - level - message) and create this module's logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Set up cache directory: load_t5_model() builds its model path under this
# directory and loads with local_files_only=True.
CACHE_DIR = "/app/.cache/huggingface/hub"
# NOTE(review): these variables are assigned after `transformers` has already
# been imported at the top of the file; presumably the Hugging Face libraries
# read them at import time, in which case the assignments have no effect --
# confirm, and move them above the imports if they are meant to apply.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
# Create the FastAPI app. The description is an Arabic string that reads
# roughly: "Combining the trained T5 model with Mistral-7B (GGUF) inside a Space".
app = FastAPI(
title="MGZON Smart Assistant",
description="دمج نموذج T5 المدرب مع Mistral-7B (GGUF) داخل Space"
)
# Initialize model variables. These module-level globals are filled in by
# load_t5_model() and load_mistral_model() and read by the request handlers.
# Tokenizer and seq2seq model for MGZON/mgzon-flan-t5-base (None until loaded).
t5_tokenizer = None
t5_model = None
# llama_cpp.Llama instance for the Mistral GGUF file (None until loaded).
mistral = None
# Readiness flags: each is set to True only after its load has succeeded.
t5_loaded = False
mistral_loaded = False
# Root endpoint
@app.get("/")
async def root():
    """Return a fixed JSON message confirming that the service is running.

    The response carries explicit ``Cache-Control`` and ``Connection``
    headers; the call time is written to the log on every hit.
    """
    logger.info(f"Root endpoint called at {time.time()}")
    payload = {"message": "MGZON Smart Assistant is running"}
    response_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
    return JSONResponse(content=payload, headers=response_headers)
# Health check endpoint
@app.get("/health")
async def health_check():
    """Report whether the T5 model has finished loading.

    Returns a JSON body ``{"status": "healthy"}`` once ``t5_loaded`` is true
    and ``{"status": "loading"}`` otherwise. The Mistral model is not taken
    into account.
    """
    logger.info(f"Health check endpoint called at {time.time()}")
    if t5_loaded:
        status = "healthy"
    else:
        status = "loading"
    return JSONResponse(
        content={"status": status},
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
# Async function to load T5 model
def _resolve_t5_snapshot_dir(snapshots_root):
    """Return the directory under *snapshots_root* that holds the model files.

    In the Hugging Face hub cache layout the ``snapshots`` directory contains
    one sub-directory per revision, and the model files (``config.json`` etc.)
    live inside those sub-directories. If *snapshots_root* itself contains
    ``config.json`` it is returned unchanged; otherwise the most recently
    modified sub-directory containing ``config.json`` is returned. When
    nothing suitable is found, *snapshots_root* is returned as-is so that the
    loader reports the real error.
    """
    if os.path.isfile(os.path.join(snapshots_root, "config.json")):
        return snapshots_root
    try:
        entries = os.listdir(snapshots_root)
    except OSError:
        return snapshots_root
    candidates = [
        path
        for path in (os.path.join(snapshots_root, entry) for entry in entries)
        if os.path.isfile(os.path.join(path, "config.json"))
    ]
    if not candidates:
        return snapshots_root
    return max(candidates, key=os.path.getmtime)


async def load_t5_model():
    """Load the MGZON FLAN-T5 tokenizer and model into the module globals.

    The blocking ``from_pretrained`` calls run in the default thread-pool
    executor so that the event loop can keep serving requests (for example
    ``/health``) while the model loads. ``t5_tokenizer`` and ``t5_model`` are
    assigned, and ``t5_loaded`` is set to True, only after both loads succeed.

    Raises:
        RuntimeError: if the tokenizer or the model cannot be loaded; the
            original exception is attached as ``__cause__``.
    """
    global t5_tokenizer, t5_model, t5_loaded
    start_time = time.time()
    logger.info(f"Starting T5 model loading at {start_time}")
    loop = asyncio.get_running_loop()
    try:
        snapshots_root = os.path.join(CACHE_DIR, "models--MGZON--mgzon-flan-t5-base/snapshots")
        # Point at the concrete revision directory, not the "snapshots" parent.
        model_path = _resolve_t5_snapshot_dir(snapshots_root)

        logger.info(f"Loading tokenizer for MGZON/mgzon-flan-t5-base from {model_path}")
        # torch_dtype is a model option and has no meaning for a tokenizer.
        tokenizer = await loop.run_in_executor(
            None,
            lambda: AutoTokenizer.from_pretrained(model_path, local_files_only=True),
        )
        logger.info(f"Successfully loaded tokenizer for MGZON/mgzon-flan-t5-base in {time.time() - start_time} seconds")

        logger.info(f"Loading model for MGZON/mgzon-flan-t5-base from {model_path}")
        # NOTE(review): float16 is kept to reduce memory usage; confirm that
        # half-precision generation works on the deployment hardware.
        model = await loop.run_in_executor(
            None,
            lambda: AutoModelForSeq2SeqLM.from_pretrained(
                model_path,
                local_files_only=True,
                torch_dtype="float16",
            ),
        )
        logger.info(f"Successfully loaded model for MGZON/mgzon-flan-t5-base in {time.time() - start_time} seconds")

        # Publish both objects before flipping the readiness flag.
        t5_tokenizer = tokenizer
        t5_model = model
        t5_loaded = True
    except Exception as e:
        logger.error(f"Failed to load T5 model: {str(e)}", exc_info=True)
        t5_loaded = False
        raise RuntimeError(f"Failed to load T5 model: {str(e)}") from e
    finally:
        end_time = time.time()
        logger.info(f"T5 model loading completed in {end_time - start_time} seconds")
# Async function to load Mistral model
# Serializes model construction so concurrent requests cannot load it twice.
_mistral_load_lock = threading.Lock()


def _load_mistral_blocking(gguf_path):
    """Construct the Llama model from *gguf_path* unless it is already loaded.

    Intended to run in a worker thread. The lock plus the re-check of
    ``mistral_loaded`` make a second concurrent caller wait for the first load
    and then return without loading again.
    """
    global mistral, mistral_loaded
    with _mistral_load_lock:
        if mistral_loaded:
            return
        mistral = Llama(
            model_path=gguf_path,
            n_ctx=512,
            n_threads=1,
            n_batch=128,
            verbose=True,
        )
        mistral_loaded = True


async def load_mistral_model():
    """Load the Mistral GGUF model into the ``mistral`` global.

    The model file is looked up at ``models/mistral-7b-instruct-v0.1.Q2_K.gguf``
    relative to the current working directory. The blocking construction runs
    in the default thread-pool executor so the event loop stays responsive.
    Calling this when the model is already loaded is a no-op.

    Raises:
        RuntimeError: if the GGUF file is missing or the model cannot be
            constructed; the original exception is attached as ``__cause__``.
    """
    start_time = time.time()
    logger.info(f"Starting Mistral model loading at {start_time}")
    try:
        gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q2_K.gguf")
        if not os.path.exists(gguf_path):
            # Logged once, with traceback, by the handler below.
            raise FileNotFoundError(f"Mistral GGUF file not found at {gguf_path}")
        logger.info(f"Loading Mistral model from {gguf_path}")
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _load_mistral_blocking, gguf_path)
        logger.info(f"Successfully loaded Mistral model from {gguf_path} in {time.time() - start_time} seconds")
    except Exception as e:
        # mistral_loaded is deliberately not reset here: it only becomes True
        # after a successful construction, and resetting it could clobber a
        # load completed by a concurrent caller.
        logger.error(f"Failed to load Mistral model: {str(e)}", exc_info=True)
        raise RuntimeError(f"Failed to load Mistral model: {str(e)}") from e
    finally:
        end_time = time.time()
        logger.info(f"Mistral model loading completed in {end_time - start_time} seconds")
# Run T5 model loading in the background
# Strong references to fire-and-forget tasks; the event loop itself only keeps
# weak references, so an unreferenced task may be garbage-collected.
_background_tasks = set()


def _on_background_task_done(task):
    """Drop the finished task and retrieve its exception, if any.

    Retrieving the exception prevents asyncio's "Task exception was never
    retrieved" warning; the failure is logged here so it is not lost.
    """
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error(f"Background task failed: {exc}")


@app.on_event("startup")
async def startup_event():
    """Schedule T5 loading as a background task when the application starts.

    Only T5 is loaded at startup; the Mistral model is loaded on first use by
    the /ask endpoint. The handler returns immediately and does not wait for
    the load to finish.
    """
    logger.info(f"Startup event triggered at {time.time()}")
    task = asyncio.create_task(load_t5_model())  # Load only T5 at startup
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
# Define request schema
class AskRequest(BaseModel):
    """JSON body accepted by the ``/ask`` endpoint."""

    # Free-text question; the endpoint strips it and rejects it when empty.
    question: str
    # Generation budget handed to whichever model answers the request.
    max_new_tokens: int = 150
# Endpoint: /ask
# Serializes model inference: generation runs in worker threads, and only one
# generation is allowed to use the models at a time.
_inference_lock = threading.Lock()


def _generate_with_t5(question, max_new_tokens):
    """Answer *question* with the T5 model; blocking, runs in a worker thread."""
    with _inference_lock:
        inputs = t5_tokenizer(question, return_tensors="pt", truncation=True, max_length=256)
        # max_new_tokens (not max_length) matches the meaning of the request field.
        out_ids = t5_model.generate(**inputs, max_new_tokens=max_new_tokens)
        return t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)


def _generate_with_mistral(question, max_new_tokens):
    """Answer *question* with the Mistral model; blocking, runs in a worker thread."""
    with _inference_lock:
        out = mistral(prompt=question, max_tokens=max_new_tokens, temperature=0.7)
        return out["choices"][0]["text"].strip()


@app.post("/ask")
async def ask(req: AskRequest):
    """Answer a question with T5 or Mistral and return the generated text.

    Questions containing "mgzon", "flan" or "t5" (case-insensitive) are routed
    to the T5 model; all others go to Mistral, which is loaded on first use.
    Generation runs in the default thread-pool executor so the event loop
    stays free for other requests.

    Returns:
        ``{"model": <model name>, "response": <generated text>}``.

    Raises:
        HTTPException: 503 if T5 is still loading or Mistral cannot be loaded,
            400 if the question is empty after stripping, 500 for any other
            failure during generation.
    """
    logger.info(f"Received ask request: {req.question} at {time.time()}")
    if not t5_loaded:
        logger.error("T5 model not loaded yet")
        raise HTTPException(status_code=503, detail="T5 model is still loading, please try again later")
    q = req.question.strip()
    if not q:
        logger.error("Empty question received")
        raise HTTPException(status_code=400, detail="Empty question")
    try:
        if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
            # Use T5 model
            logger.info("Using MGZON-FLAN-T5 model")
            model_name = "MGZON-FLAN-T5"
            generate = _generate_with_t5
        else:
            # Load Mistral model if not loaded
            if not mistral_loaded:
                logger.info("Mistral model not loaded, loading now...")
                try:
                    await load_mistral_model()
                except RuntimeError as load_err:
                    # The loader signals failure by raising; report it as 503.
                    raise HTTPException(status_code=503, detail="Failed to load Mistral model") from load_err
            if not mistral_loaded:
                # Defensive: the loader returned without making the model available.
                raise HTTPException(status_code=503, detail="Failed to load Mistral model")
            # Use Mistral model
            logger.info("Using Mistral-7B-GGUF model")
            model_name = "Mistral-7B-GGUF"
            generate = _generate_with_mistral

        loop = asyncio.get_running_loop()
        answer = await loop.run_in_executor(None, generate, q, req.max_new_tokens)
        logger.info(f"Response generated by {model_name}: {answer}")
        return {"model": model_name, "response": answer}
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 503 above) instead of turning
        # them into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"خطأ أثناء معالجة الطلب: {str(e)}") from e
# Run the app
if __name__ == "__main__":
    # Server settings collected in one place and passed through to uvicorn.
    # NOTE(review): limit_max_requests makes uvicorn terminate the process
    # after that many requests -- confirm that something restarts the server,
    # otherwise the service stops after 50 requests.
    server_options = {
        "host": "0.0.0.0",
        "port": 8080,
        "log_level": "info",
        "workers": 1,
        "timeout_keep_alive": 15,
        "limit_concurrency": 5,
        "limit_max_requests": 50,
    }
    uvicorn.run(app, **server_options)