import asyncio
import logging
import os
import time

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from llama_cpp import Llama
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
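# Point the Hugging Face cache at a Space-local directory so the fine-tuned T5
# weights can be resolved from disk (local_files_only=True below).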
CACHE_DIR = "/app/.cache/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
|
|
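# FastAPI application exposing the root, health-check, and question-answering endpoints.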
app = FastAPI(
    title="MGZON Smart Assistant",
    description="Combines the fine-tuned T5 model with Mistral-7B (GGUF) inside the Space"
)
|
|
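# Global model state. The T5 model is loaded in the background at startup;
# Mistral is only loaded on demand by the /ask endpoint.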
t5_tokenizer = None
t5_model = None
mistral = None
t5_loaded = False
mistral_loaded = False
|
|
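# Simple root endpoint confirming the service is running.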
@app.get("/")
async def root():
    logger.info(f"Root endpoint called at {time.time()}")
    return JSONResponse(
        content={"message": "MGZON Smart Assistant is running"},
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )
|
|
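# Readiness endpoint: reports "loading" until the T5 model has finished loading.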
@app.get("/health")
async def health_check():
    logger.info(f"Health check endpoint called at {time.time()}")
    return JSONResponse(
        content={"status": "healthy" if t5_loaded else "loading"},
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )
|
|
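# Load the fine-tuned T5 tokenizer and model from the local Hugging Face cache.
# The blocking from_pretrained() calls run in a worker thread (asyncio.to_thread)
# so the event loop, and therefore /health, stays responsive while loading.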
async def load_t5_model():
    global t5_tokenizer, t5_model, t5_loaded
    start_time = time.time()
    logger.info(f"Starting T5 model loading at {start_time}")
    try:
        # The cache stores weights under snapshots/<revision-hash>/; this assumes a
        # single cached snapshot and resolves it instead of loading "snapshots" itself.
        snapshots_dir = os.path.join(CACHE_DIR, "models--MGZON--mgzon-flan-t5-base", "snapshots")
        T5_MODEL_PATH = os.path.join(snapshots_dir, os.listdir(snapshots_dir)[0])
        logger.info(f"Loading tokenizer for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
        t5_tokenizer = await asyncio.to_thread(
            AutoTokenizer.from_pretrained,
            T5_MODEL_PATH,
            local_files_only=True
        )
        logger.info(f"Successfully loaded tokenizer for MGZON/mgzon-flan-t5-base in {time.time() - start_time:.2f} seconds")
        logger.info(f"Loading model for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
        t5_model = await asyncio.to_thread(
            AutoModelForSeq2SeqLM.from_pretrained,
            T5_MODEL_PATH,
            local_files_only=True,
            torch_dtype="float16"
        )
        logger.info(f"Successfully loaded model for MGZON/mgzon-flan-t5-base in {time.time() - start_time:.2f} seconds")
        t5_loaded = True
    except Exception as e:
        logger.error(f"Failed to load T5 model: {str(e)}", exc_info=True)
        t5_loaded = False
        raise RuntimeError(f"Failed to load T5 model: {str(e)}")
    finally:
        end_time = time.time()
        logger.info(f"T5 model loading completed in {end_time - start_time:.2f} seconds")
|
|
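# Lazily load the quantized Mistral-7B GGUF model with llama.cpp. The Q2_K quant,
# small context window, and single thread keep memory and CPU usage modest; the
# blocking Llama() constructor runs in a worker thread for the same reason as above.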
async def load_mistral_model():
    global mistral, mistral_loaded
    start_time = time.time()
    logger.info(f"Starting Mistral model loading at {start_time}")
    try:
        gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q2_K.gguf")
        if not os.path.exists(gguf_path):
            logger.error(f"Mistral GGUF file not found at {gguf_path}")
            raise RuntimeError(f"Mistral GGUF file not found at {gguf_path}")
        logger.info(f"Loading Mistral model from {gguf_path}")
        mistral = await asyncio.to_thread(
            Llama,
            model_path=gguf_path,
            n_ctx=512,
            n_threads=1,
            n_batch=128,
            verbose=True
        )
        logger.info(f"Successfully loaded Mistral model from {gguf_path} in {time.time() - start_time:.2f} seconds")
        mistral_loaded = True
    except Exception as e:
        logger.error(f"Failed to load Mistral model: {str(e)}", exc_info=True)
        mistral_loaded = False
        raise RuntimeError(f"Failed to load Mistral model: {str(e)}")
    finally:
        end_time = time.time()
        logger.info(f"Mistral model loading completed in {end_time - start_time:.2f} seconds")
|
|
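# Kick off T5 loading in the background so the server starts accepting requests
# (and answering /health) immediately.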
@app.on_event("startup")
async def startup_event():
    logger.info(f"Startup event triggered at {time.time()}")
    asyncio.create_task(load_t5_model())
|
|
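# Request body for /ask: the question text and an optional generation budget.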
class AskRequest(BaseModel):
    question: str
    max_new_tokens: int = 150
|
|
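# Answer a question, routing by a simple keyword heuristic: questions mentioning
# MGZON/FLAN/T5 go to the fine-tuned T5 model, everything else goes to Mistral,
# which is loaded on first use.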
@app.post("/ask")
async def ask(req: AskRequest):
    logger.info(f"Received ask request: {req.question} at {time.time()}")
    if not t5_loaded:
        logger.error("T5 model not loaded yet")
        raise HTTPException(status_code=503, detail="T5 model is still loading, please try again later")

    q = req.question.strip()
    if not q:
        logger.error("Empty question received")
        raise HTTPException(status_code=400, detail="Empty question")

    try:
        if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
            logger.info("Using MGZON-FLAN-T5 model")
            inputs = t5_tokenizer(q, return_tensors="pt", truncation=True, max_length=256)
            out_ids = t5_model.generate(**inputs, max_new_tokens=req.max_new_tokens)
            answer = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
            model_name = "MGZON-FLAN-T5"
        else:
            if not mistral_loaded:
                logger.info("Mistral model not loaded, loading now...")
                await load_mistral_model()
            if not mistral_loaded:
                raise HTTPException(status_code=503, detail="Failed to load Mistral model")

            logger.info("Using Mistral-7B-GGUF model")
            out = mistral(prompt=q, max_tokens=req.max_new_tokens, temperature=0.7)
            answer = out["choices"][0]["text"].strip()
            model_name = "Mistral-7B-GGUF"
        logger.info(f"Response generated by {model_name}: {answer}")
        return {"model": model_name, "response": answer}
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error while processing the request: {str(e)}")
|
|
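# Direct entry point: run a single worker with conservative keep-alive and
# concurrency limits.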
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
        log_level="info",
        workers=1,
        timeout_keep_alive=15,
        limit_concurrency=5,
        limit_max_requests=50
    )