MGZON committed on
Commit
fca51d8
·
verified ·
1 Parent(s): 9ec6eb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -14
app.py CHANGED
@@ -1,71 +1,104 @@
1
  import os
 
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
  from llama_cpp import Llama
6
 
7
- # إعداد مسار الـ cache
 
 
 
 
8
  CACHE_DIR = os.environ.get("HF_HOME", "/app/.cache/huggingface")
9
 
10
- # تأكد من أن المكتبتين تقرأ المتغيّرات البيئية
11
  os.environ["HF_HOME"] = CACHE_DIR
12
 
13
- # إنشاء التطبيق
14
  app = FastAPI(
15
  title="MGZON Smart Assistant",
16
- description="دمج نموذج T5 المدرب مع Mistral7B (GGUF) داخل Space"
17
  )
18
 
19
- # تحميل نموذج T5 المدرب من Hub
 
 
 
 
 
20
  T5_REPO = "MGZON/mgzon-flan-t5-base"
21
  try:
22
- t5_tokenizer = AutoTokenizer.from_pretrained(T5_REPO, cache_dir=CACHE_DIR)
23
- t5_model = AutoModelForSeq2SeqLM.from_pretrained(T5_REPO, cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
24
  except Exception as e:
25
- raise RuntimeError(f"فشل تحميل نموذج T5 من {T5_REPO}: {str(e)}")
 
26
 
27
- # تحميل ملف Mistral .gguf
28
  gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q4_K_M.gguf")
29
  if not os.path.exists(gguf_path):
 
30
  raise RuntimeError(
31
- f"ملف Mistral .gguf غير موجود في {gguf_path}. "
32
  "تأكد من أن ملف setup.sh تم تنفيذه أثناء الـ build."
33
  )
34
 
35
  try:
 
36
  mistral = Llama(
37
  model_path=gguf_path,
38
  n_ctx=2048,
39
  n_threads=8,
40
  # إذا كان لديك GPU، يمكنك إضافة: n_gpu_layers=35
41
  )
 
42
  except Exception as e:
43
- raise RuntimeError(f"فشل تحميل نموذج Mistral من {gguf_path}: {str(e)}")
 
44
 
45
- # تعريف شكل الطلب (JSON)
46
  class AskRequest(BaseModel):
47
  question: str
48
  max_new_tokens: int = 150
49
 
50
- # نقطة النهاية /ask
51
  @app.post("/ask")
52
  def ask(req: AskRequest):
 
53
  q = req.question.strip()
54
  if not q:
 
55
  raise HTTPException(status_code=400, detail="Empty question")
56
 
57
  try:
58
  if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
59
  # نموذج T5
 
60
  inputs = t5_tokenizer(q, return_tensors="pt", truncation=True, max_length=256)
61
  out_ids = t5_model.generate(**inputs, max_length=req.max_new_tokens)
62
  answer = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
63
  model_name = "MGZON-FLAN-T5"
64
  else:
65
  # نموذج Mistral
 
66
  out = mistral(prompt=q, max_tokens=req.max_new_tokens)
67
  answer = out["choices"][0]["text"].strip()
68
  model_name = "Mistral-7B-GGUF"
 
69
  return {"model": model_name, "response": answer}
70
  except Exception as e:
71
- raise HTTPException(status_code=500, detail=f"خطأ أثناء معالجة الطلب: {str(e)}")
 
 
1
  import os
2
+ import logging
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  from llama_cpp import Llama
7
 
8
+ # Set up logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Set up cache directory
13
  CACHE_DIR = os.environ.get("HF_HOME", "/app/.cache/huggingface")
14
 
15
+ # Ensure libraries use the cache directory
16
  os.environ["HF_HOME"] = CACHE_DIR
17
 
18
+ # Create the FastAPI app
19
  app = FastAPI(
20
  title="MGZON Smart Assistant",
21
+ description="دمج نموذج T5 المدرب مع Mistral-7B (GGUF) داخل Space"
22
  )
23
 
24
+ # Health check endpoint
25
+ @app.get("/health")
26
+ async def health_check():
27
+ return {"status": "healthy"}
28
+
29
+ # Load T5 model from Hub
30
  T5_REPO = "MGZON/mgzon-flan-t5-base"
31
  try:
32
+ logger.info(f"Loading tokenizer for {T5_REPO} with HF_TOKEN")
33
+ t5_tokenizer = AutoTokenizer.from_pretrained(
34
+ T5_REPO,
35
+ cache_dir=CACHE_DIR,
36
+ use_auth_token=os.environ.get("HF_TOKEN")
37
+ )
38
+ logger.info(f"Successfully loaded tokenizer for {T5_REPO}")
39
+ logger.info(f"Loading model for {T5_REPO}")
40
+ t5_model = AutoModelForSeq2SeqLM.from_pretrained(
41
+ T5_REPO,
42
+ cache_dir=CACHE_DIR,
43
+ use_auth_token=os.environ.get("HF_TOKEN")
44
+ )
45
+ logger.info(f"Successfully loaded model for {T5_REPO}")
46
  except Exception as e:
47
+ logger.error(f"Failed to load T5 model from {T5_REPO}: {str(e)}")
48
+ raise RuntimeError(f"Failed to load T5 model from {T5_REPO}: {str(e)}")
49
 
50
+ # Load Mistral GGUF model
51
  gguf_path = os.path.abspath("models/mistral-7b-instruct-v0.1.Q4_K_M.gguf")
52
  if not os.path.exists(gguf_path):
53
+ logger.error(f"Mistral GGUF file not found at {gguf_path}")
54
  raise RuntimeError(
55
+ f"Mistral GGUF file not found at {gguf_path}. "
56
  "تأكد من أن ملف setup.sh تم تنفيذه أثناء الـ build."
57
  )
58
 
59
  try:
60
+ logger.info(f"Loading Mistral model from {gguf_path}")
61
  mistral = Llama(
62
  model_path=gguf_path,
63
  n_ctx=2048,
64
  n_threads=8,
65
  # إذا كان لديك GPU، يمكنك إضافة: n_gpu_layers=35
66
  )
67
+ logger.info(f"Successfully loaded Mistral model from {gguf_path}")
68
  except Exception as e:
69
+ logger.error(f"Failed to load Mistral model from {gguf_path}: {str(e)}")
70
+ raise RuntimeError(f"Failed to load Mistral model from {gguf_path}: {str(e)}")
71
 
72
+ # Define request schema
73
  class AskRequest(BaseModel):
74
  question: str
75
  max_new_tokens: int = 150
76
 
77
+ # Endpoint: /ask
78
  @app.post("/ask")
79
  def ask(req: AskRequest):
80
+ logger.info(f"Received question: {req.question}")
81
  q = req.question.strip()
82
  if not q:
83
+ logger.error("Empty question received")
84
  raise HTTPException(status_code=400, detail="Empty question")
85
 
86
  try:
87
  if any(tok in q.lower() for tok in ["mgzon", "flan", "t5"]):
88
  # نموذج T5
89
+ logger.info("Using MGZON-FLAN-T5 model")
90
  inputs = t5_tokenizer(q, return_tensors="pt", truncation=True, max_length=256)
91
  out_ids = t5_model.generate(**inputs, max_length=req.max_new_tokens)
92
  answer = t5_tokenizer.decode(out_ids[0], skip_special_tokens=True)
93
  model_name = "MGZON-FLAN-T5"
94
  else:
95
  # نموذج Mistral
96
+ logger.info("Using Mistral-7B-GGUF model")
97
  out = mistral(prompt=q, max_tokens=req.max_new_tokens)
98
  answer = out["choices"][0]["text"].strip()
99
  model_name = "Mistral-7B-GGUF"
100
+ logger.info(f"Response generated by {model_name}: {answer}")
101
  return {"model": model_name, "response": answer}
102
  except Exception as e:
103
+ logger.error(f"Error processing request: {str(e)}")
104
+ raise HTTPException(status_code=500, detail=f"خطأ أثناء معالجة الطلب: {str(e)}")