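"""MGZON Smart Assistant.

A FastAPI service that answers questions with one of two models: the
fine-tuned MGZON/mgzon-flan-t5-base (loaded in the background at startup)
or a quantized Mistral-7B-Instruct GGUF model (loaded lazily on first use).
"""
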
import os
import logging
import time
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from llama_cpp import Llama
import asyncio
import uvicorn

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Point the Hugging Face caches at a writable directory inside the Space.
# TRANSFORMERS_CACHE is deprecated in recent transformers releases in favor
# of HF_HOME, so both are set for compatibility.
CACHE_DIR = "/app/.cache/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR

# Create the FastAPI app
app = FastAPI(
    title="MGZON Smart Assistant",
    description="دمج نموذج T5 المدرب مع Mistral-7B (GGUF) داخل Space"
)

# Model handles and readiness flags (set by the loader functions below)
t5_tokenizer = None
t5_model = None
mistral = None
t5_loaded = False
mistral_loaded = False

# Root endpoint
@app.get("/")
async def root():
    logger.info(f"Root endpoint called at {time.time()}")
    return JSONResponse(
        content={"message": "MGZON Smart Assistant is running"},
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )

# Health check endpoint: reports "loading" until the startup task has
# finished loading the T5 model
@app.get("/health")
async def health_check():
    logger.info(f"Health check endpoint called at {time.time()}")
    return JSONResponse(
        content={"status": "healthy" if t5_loaded else "loading"},
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
    )

# Async function to load the T5 model. The blocking from_pretrained calls run
# in a worker thread via asyncio.to_thread so the event loop (and with it the
# /health endpoint) stays responsive while the model loads.
async def load_t5_model():
    global t5_tokenizer, t5_model, t5_loaded
    start_time = time.time()
    logger.info("Starting T5 model loading")
    try:
        T5_MODEL_PATH = os.path.join(CACHE_DIR, "models--MGZON--mgzon-flan-t5-base/snapshots")
        logger.info(f"Loading tokenizer for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
        # Tokenizers take no torch_dtype argument; that option only applies to the model
        t5_tokenizer = await asyncio.to_thread(
            AutoTokenizer.from_pretrained,
            T5_MODEL_PATH,
            local_files_only=True
        )
        logger.info(f"Loaded tokenizer for MGZON/mgzon-flan-t5-base in {time.time() - start_time:.2f} seconds")
        logger.info(f"Loading model for MGZON/mgzon-flan-t5-base from {T5_MODEL_PATH}")
        t5_model = await asyncio.to_thread(
            AutoModelForSeq2SeqLM.from_pretrained,
            T5_MODEL_PATH,
            local_files_only=True,
            torch_dtype="float16"  # Halve memory usage relative to float32
        )
        logger.info(f"Loaded model for MGZON/mgzon-flan-t5-base in {time.time() - start_time:.2f} seconds")
        t5_loaded = True
    except Exception as e:
        logger.error(f"Failed to load T5 model: {e}", exc_info=True)
        t5_loaded = False
        raise RuntimeError(f"Failed to load T5 model: {e}") from e
    finally:
        logger.info(f"T5 model loading finished in {time.time() - start_time:.2f} seconds")

# Async function to load the Mistral GGUF model on demand; the blocking
# Llama constructor runs in a worker thread to keep the event loop free.
async def load_mistral_model():
    global mistral, mistral_loaded
    start_time = time.time()
    logger.info("Starting Mistral model loading")
    try:
        gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q2_K.gguf")
        if not os.path.exists(gguf_path):
            logger.error(f"Mistral GGUF file not found at {gguf_path}")
            raise RuntimeError(f"Mistral GGUF file not found at {gguf_path}")
        logger.info(f"Loading Mistral model from {gguf_path}")
        mistral = await asyncio.to_thread(
            Llama,
            model_path=gguf_path,
            n_ctx=512,      # Small context window to fit the Space's memory budget
            n_threads=1,
            n_batch=128,
            verbose=True
        )
        logger.info(f"Loaded Mistral model from {gguf_path} in {time.time() - start_time:.2f} seconds")
        mistral_loaded = True
    except Exception as e:
        logger.error(f"Failed to load Mistral model: {e}", exc_info=True)
        mistral_loaded = False
        raise RuntimeError(f"Failed to load Mistral model: {e}") from e
    finally:
        logger.info(f"Mistral model loading finished in {time.time() - start_time:.2f} seconds")

# Run T5 model loading in the background at startup; Mistral is deferred
# until the first request that needs it. (Newer FastAPI versions prefer
# lifespan handlers over on_event, which is deprecated.)
@app.on_event("startup")
async def startup_event():
    logger.info(f"Startup event triggered at {time.time()}")
    asyncio.create_task(load_t5_model())  # Load only T5 at startup

# Define request schema
class AskRequest(BaseModel):
    question: str
    max_new_tokens: int = 150
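    # The default of 150 for max_new_tokens is an application choice, not a
    # model limit; it caps the generated tokens for either model.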

# Endpoint: /ask
@app.post("/ask")
async def ask(req: AskRequest):
    logger.info(f"Received ask request: {req.question} at {time.time()}")
    if not t5_loaded:
        logger.error("T5 model not loaded yet")
        raise HTTPException(status_code=503, detail="T5 model is still loading, please try again later")

    q = req.question.strip()
    if not q:
        logger.error("Empty question received")
        raise HTTPException(status_code=400, detail="Empty question")

    try:
        # Simple keyword routing: questions mentioning MGZON/FLAN/T5 go to the
        # fine-tuned T5 model; everything else goes to Mistral.
        if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
            logger.info("Using MGZON-FLAN-T5 model")
            inputs = t5_tokenizer(q, return_tensors="pt", truncation=True, max_length=256)
            # Run the blocking generate() in a worker thread, and use
            # max_new_tokens so the request field maps directly onto the
            # number of generated tokens
            out_ids = await asyncio.to_thread(
                t5_model.generate, **inputs, max_new_tokens=req.max_new_tokens
            )
            answer = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
            model_name = "MGZON-FLAN-T5"
        else:
            # Load Mistral lazily on first use
            if not mistral_loaded:
                logger.info("Mistral model not loaded, loading now...")
                await load_mistral_model()
                if not mistral_loaded:
                    raise HTTPException(status_code=503, detail="Failed to load Mistral model")
            logger.info("Using Mistral-7B-GGUF model")
            out = await asyncio.to_thread(
                mistral, prompt=q, max_tokens=req.max_new_tokens, temperature=0.7
            )
            answer = out["choices"][0]["text"].strip()
            model_name = "Mistral-7B-GGUF"
        logger.info(f"Response generated by {model_name}: {answer}")
        return {"model": model_name, "response": answer}
    except HTTPException:
        # Let explicit HTTP errors (e.g. the 503 above) pass through instead
        # of being converted into a generic 500
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error while processing the request: {str(e)}")

# Run the app
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8080,
        log_level="info",
        workers=1,
        timeout_keep_alive=15,
        limit_concurrency=5,
        limit_max_requests=50
    )
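
# Example usage once the server is running (hypothetical question; the port
# matches the uvicorn config above):
#   curl -X POST http://localhost:8080/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is MGZON?", "max_new_tokens": 100}'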