import os
import logging
import time
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from llama_cpp import Llama
import asyncio
import uvicorn
# --- Logging ---------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# --- Hugging Face cache location -------------------------------------------
# Both variables must point at the same pre-populated cache inside the image.
CACHE_DIR = "/app/.cache/huggingface/hub"
for _cache_env in ("TRANSFORMERS_CACHE", "HF_HOME"):
    os.environ[_cache_env] = CACHE_DIR

# --- FastAPI application ---------------------------------------------------
app = FastAPI(
    title="MGZON Smart Assistant",
    description="دمج نموذج T5 المدرب مع Mistral-7B (GGUF) داخل Space",
)

# Model handles and readiness flags, populated lazily by the loader coroutines.
t5_tokenizer = None
t5_model = None
mistral = None
t5_loaded = False
mistral_loaded = False
# Root endpoint
# Root endpoint
@app.get("/")
async def root():
    """Liveness endpoint: confirms the service process is up."""
    logger.info(f"Root endpoint called at {time.time()}")
    payload = {"message": "MGZON Smart Assistant is running"}
    no_cache = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
    return JSONResponse(content=payload, headers=no_cache)
# Health check endpoint
# Health check endpoint
@app.get("/health")
async def health_check():
    """Readiness endpoint: reports 'healthy' once the T5 model has loaded."""
    logger.info(f"Health check endpoint called at {time.time()}")
    status = "healthy" if t5_loaded else "loading"
    return JSONResponse(
        content={"status": status},
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
# Async function to load T5 model
async def load_t5_model():
global t5_tokenizer, t5_model, t5_loaded
start_time = time.time()
logger.info(f"Starting T5 model loading at {start_time}")
try:
T5_MODEL_PATH = os.path.join(CACHE_DIR, "models--MGZON--mgzon-flan-t5-base/snapshots")
logger.info(f"Loading tokenizer for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
t5_tokenizer = AutoTokenizer.from_pretrained(
T5_MODEL_PATH,
local_files_only=True,
torch_dtype="float16" # Reduce memory usage
)
logger.info(f"Successfully loaded tokenizer for MGZON/mgzon-flan-t5-base in {time.time() - start_time} seconds")
logger.info(f"Loading model for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
t5_model = AutoModelForSeq2SeqLM.from_pretrained(
T5_MODEL_PATH,
local_files_only=True,
torch_dtype="float16" # Reduce memory usage
)
logger.info(f"Successfully loaded model for MGZON/mgzon-flan-t5-base in {time.time() - start_time} seconds")
t5_loaded = True
except Exception as e:
logger.error(f"Failed to load T5 model: {str(e)}", exc_info=True)
t5_loaded = False
raise RuntimeError(f"Failed to load T5 model: {str(e)}")
finally:
end_time = time.time()
logger.info(f"T5 model loading completed in {end_time - start_time} seconds")
# Async function to load Mistral model
# Async function to load Mistral model
async def load_mistral_model():
    """Load the Mistral-7B GGUF model via llama.cpp on first use.

    This is awaited from a live request handler, so the blocking ``Llama``
    constructor runs in a worker thread via ``asyncio.to_thread`` — otherwise
    every other request on the event loop would freeze for the duration of
    the load. Sets the module-level ``mistral_loaded`` flag on success.

    Raises:
        RuntimeError: if the GGUF file is missing or loading fails.
    """
    global mistral, mistral_loaded
    start_time = time.time()
    logger.info(f"Starting Mistral model loading at {start_time}")
    try:
        gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q2_K.gguf")
        if not os.path.exists(gguf_path):
            logger.error(f"Mistral GGUF file not found at {gguf_path}")
            raise RuntimeError(f"Mistral GGUF file not found at {gguf_path}")
        logger.info(f"Loading Mistral model from {gguf_path}")
        mistral = await asyncio.to_thread(
            Llama,
            model_path=gguf_path,
            n_ctx=512,
            n_threads=1,
            n_batch=128,
            verbose=True,
        )
        logger.info(f"Successfully loaded Mistral model from {gguf_path} in {time.time() - start_time} seconds")
        mistral_loaded = True
    except Exception as e:
        logger.error(f"Failed to load Mistral model: {str(e)}", exc_info=True)
        mistral_loaded = False
        # Chain the original cause so the traceback keeps the real failure.
        raise RuntimeError(f"Failed to load Mistral model: {str(e)}") from e
    finally:
        end_time = time.time()
        logger.info(f"Mistral model loading completed in {end_time - start_time} seconds")
# Run T5 model loading in the background
# Run T5 model loading in the background
@app.on_event("startup")
async def startup_event():
    """Kick off T5 loading in the background so startup returns immediately.

    The task handle is kept on ``app.state`` because asyncio holds only a
    weak reference to tasks — an unreferenced task can be garbage-collected
    mid-flight. A done-callback retrieves and logs any load failure so the
    exception is never left unretrieved.
    """
    logger.info(f"Startup event triggered at {time.time()}")
    task = asyncio.create_task(load_t5_model())  # Load only T5 at startup

    def _log_load_result(t):
        # Retrieve the task's exception (if any); load_t5_model re-raises
        # after logging, and nothing else awaits this task.
        if not t.cancelled() and t.exception() is not None:
            logger.error("Background T5 load failed", exc_info=t.exception())

    task.add_done_callback(_log_load_result)
    app.state.t5_load_task = task
# Define request schema
# Define request schema
class AskRequest(BaseModel):
    """Request body for POST /ask."""

    question: str  # user question; the handler rejects empty/whitespace-only text
    max_new_tokens: int = 150  # generation budget forwarded to the selected model
# Endpoint: /ask
# Endpoint: /ask
@app.post("/ask")
async def ask(req: AskRequest):
    """Answer a question, routing to T5 for MGZON/FLAN/T5 topics, else Mistral.

    Returns ``{"model": <name>, "response": <text>}``. Blocking inference runs
    in a worker thread so the event loop keeps serving other requests.

    Raises:
        HTTPException: 503 while T5 is loading or Mistral fails to load,
            400 for an empty question, 500 for unexpected inference errors.
    """
    logger.info(f"Received ask request: {req.question} at {time.time()}")
    if not t5_loaded:
        logger.error("T5 model not loaded yet")
        raise HTTPException(status_code=503, detail="T5 model is still loading, please try again later")
    q = req.question.strip()
    if not q:
        logger.error("Empty question received")
        raise HTTPException(status_code=400, detail="Empty question")
    try:
        if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
            # Use T5 model (off the event loop — generate() is CPU-bound).
            logger.info("Using MGZON-FLAN-T5 model")
            answer = await asyncio.to_thread(_t5_answer, q, req.max_new_tokens)
            model_name = "MGZON-FLAN-T5"
        else:
            # Load Mistral model lazily on first non-T5 question.
            if not mistral_loaded:
                logger.info("Mistral model not loaded, loading now...")
                await load_mistral_model()
                if not mistral_loaded:
                    raise HTTPException(status_code=503, detail="Failed to load Mistral model")
            # Use Mistral model (off the event loop — llama.cpp is CPU-bound).
            logger.info("Using Mistral-7B-GGUF model")
            answer = await asyncio.to_thread(_mistral_answer, q, req.max_new_tokens)
            model_name = "Mistral-7B-GGUF"
        logger.info(f"Response generated by {model_name}: {answer}")
        return {"model": model_name, "response": answer}
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 503 above) pass through instead
        # of being rewrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"خطأ أثناء معالجة الطلب: {str(e)}")

def _t5_answer(question: str, max_new_tokens: int) -> str:
    """Blocking T5 generation; intended to run via asyncio.to_thread."""
    inputs = t5_tokenizer(question, return_tensors="pt", truncation=True, max_length=256)
    # max_new_tokens bounds *generated* tokens, matching the request field's
    # name; the original passed it as max_length (total decoder length).
    out_ids = t5_model.generate(**inputs, max_new_tokens=max_new_tokens)
    return t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)

def _mistral_answer(question: str, max_tokens: int) -> str:
    """Blocking llama.cpp completion; intended to run via asyncio.to_thread."""
    out = mistral(prompt=question, max_tokens=max_tokens, temperature=0.7)
    return out["choices"][0]["text"].strip()
# Run the app
# Run the app
if __name__ == "__main__":
    # Single worker with modest connection limits: the models are memory-heavy.
    server_options = dict(
        host="0.0.0.0",
        port=8080,
        log_level="info",
        workers=1,
        timeout_keep_alive=15,
        limit_concurrency=5,
        limit_max_requests=50,
    )
    uvicorn.run(app, **server_options)