CooLLaMACEO committed on
Commit
f895e5d
·
verified ·
1 Parent(s): 01cf8d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -8,13 +8,16 @@ from fastapi import FastAPI, HTTPException, Depends
8
  from fastapi.security.api_key import APIKeyHeader
9
  from pydantic import BaseModel
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
11
- from starlette.status import HTTP_403_FORBIDDEN
12
 
13
  # --- CONFIGURATION ---
14
  MODEL_PATH = "/app/model"
15
  API_KEY_NAME = "X-API-Key"
16
  api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
17
 
 
 
 
18
  generated_keys = {}
19
 
20
  app = FastAPI(title="Overflow-111.7B Self-Registering API")
@@ -28,31 +31,29 @@ try:
28
  sys.path.insert(0, MODEL_PATH)
29
 
30
  # 2. FORCE-REGISTER CONFIGURATION
31
- # We manually import the file to ensure the class is loaded into memory
32
  import configuration_overflow
33
- # Use the exact class name you provided
34
  conf_class = configuration_overflow.OverflowConfig
35
  AutoConfig.register("overflow", conf_class)
36
  print(f"Successfully registered 'overflow' config.")
37
 
38
  # 3. FORCE-REGISTER MODEL
39
- # We need to find the model class in modeling_overflow.py
40
  import modeling_overflow
41
- # Search for the class that ends with 'ForCausalLM'
42
  model_classes = [c for c in dir(modeling_overflow) if 'ForCausalLM' in c]
43
  if model_classes:
44
  model_class = getattr(modeling_overflow, model_classes[0])
45
  AutoModelForCausalLM.register(conf_class, model_class)
46
  print(f"Successfully registered {model_classes[0]} to AutoModel.")
47
 
48
- # 4. LOAD TOKENIZER
 
 
 
49
  print("Loading Tokenizer...")
50
  tokenizer = AutoTokenizer.from_pretrained(
51
  MODEL_PATH,
52
  trust_remote_code=True
53
  )
54
 
55
- # 5. LOAD MODEL (CPU Optimized)
56
  print("Loading Model Weights (111.7B Parameters - 1-Bit)...")
57
  model = AutoModelForCausalLM.from_pretrained(
58
  MODEL_PATH,
@@ -87,6 +88,13 @@ async def verify_auth(api_key: str = Depends(api_key_header)):
87
  # --- ENDPOINTS ---
88
  @app.post("/v1/generate")
89
  async def generate(query: Query, auth: str = Depends(verify_auth)):
 
 
 
 
 
 
 
90
  try:
91
  inputs = tokenizer(query.prompt, return_tensors="pt")
92
  with torch.no_grad():
@@ -103,7 +111,8 @@ async def generate(query: Query, auth: str = Depends(verify_auth)):
103
 
104
  @app.get("/")
105
  def health():
106
- return {"status": "active", "engine": "Overflow-111.7B"}
 
107
 
108
  if __name__ == "__main__":
109
  import uvicorn
 
8
  from fastapi.security.api_key import APIKeyHeader
9
  from pydantic import BaseModel
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
11
+ from starlette.status import HTTP_403_FORBIDDEN, HTTP_503_SERVICE_UNAVAILABLE
12
 
13
  # --- CONFIGURATION ---
14
  MODEL_PATH = "/app/model"
15
  API_KEY_NAME = "X-API-Key"
16
  api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
17
 
18
+ # Initialize variables at the global level so the functions can see them
19
+ tokenizer = None
20
+ model = None
21
  generated_keys = {}
22
 
23
  app = FastAPI(title="Overflow-111.7B Self-Registering API")
 
31
  sys.path.insert(0, MODEL_PATH)
32
 
33
  # 2. FORCE-REGISTER CONFIGURATION
 
34
  import configuration_overflow
 
35
  conf_class = configuration_overflow.OverflowConfig
36
  AutoConfig.register("overflow", conf_class)
37
  print(f"Successfully registered 'overflow' config.")
38
 
39
  # 3. FORCE-REGISTER MODEL
 
40
  import modeling_overflow
 
41
  model_classes = [c for c in dir(modeling_overflow) if 'ForCausalLM' in c]
42
  if model_classes:
43
  model_class = getattr(modeling_overflow, model_classes[0])
44
  AutoModelForCausalLM.register(conf_class, model_class)
45
  print(f"Successfully registered {model_classes[0]} to AutoModel.")
46
 
47
+ # 4. LOAD TOKENIZER & MODEL
48
+ # We use 'global' to update the variables we defined at the top
49
+ global tokenizer, model
50
+
51
  print("Loading Tokenizer...")
52
  tokenizer = AutoTokenizer.from_pretrained(
53
  MODEL_PATH,
54
  trust_remote_code=True
55
  )
56
 
 
57
  print("Loading Model Weights (111.7B Parameters - 1-Bit)...")
58
  model = AutoModelForCausalLM.from_pretrained(
59
  MODEL_PATH,
 
88
  # --- ENDPOINTS ---
89
  @app.post("/v1/generate")
90
  async def generate(query: Query, auth: str = Depends(verify_auth)):
91
+ # Safety Check: If the model hasn't finished loading yet
92
+ if tokenizer is None or model is None:
93
+ raise HTTPException(
94
+ status_code=HTTP_503_SERVICE_UNAVAILABLE,
95
+ detail="Engine is still booting up. Please wait a moment."
96
+ )
97
+
98
  try:
99
  inputs = tokenizer(query.prompt, return_tensors="pt")
100
  with torch.no_grad():
 
111
 
112
  @app.get("/")
113
  def health():
114
+ state = "active" if model else "loading"
115
+ return {"status": state, "engine": "Overflow-111.7B"}
116
 
117
  if __name__ == "__main__":
118
  import uvicorn