CooLLaMACEO commited on
Commit
6a82d80
·
verified ·
1 Parent(s): fdb020b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -18
app.py CHANGED
@@ -10,49 +10,54 @@ from pydantic import BaseModel
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
11
  from starlette.status import HTTP_403_FORBIDDEN, HTTP_503_SERVICE_UNAVAILABLE
12
 
13
- # --- CONFIGURATION ---
14
- MODEL_PATH = "/app/model"
15
- API_KEY_NAME = "X-API-Key"
16
- api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
17
-
18
- # 1. Initialize at top level
19
  tokenizer = None
20
  model = None
21
  generated_keys = {}
22
 
 
 
 
 
 
23
  app = FastAPI(title="Overflow-111.7B Self-Registering API")
24
 
25
- # --- ENGINE LOADING ---
26
  print("Starting Engine: Initializing Self-Registration...")
27
 
28
  try:
29
- # 2. MUST declare global at the very beginning of the block
30
  global tokenizer, model
31
 
 
32
  if MODEL_PATH not in sys.path:
33
  sys.path.insert(0, MODEL_PATH)
34
 
35
- # Force-Register Config
36
  import configuration_overflow
37
  conf_class = configuration_overflow.OverflowConfig
38
  AutoConfig.register("overflow", conf_class)
39
- print(f"Successfully registered 'overflow' config.")
40
 
41
- # Force-Register Model
42
  import modeling_overflow
 
43
  model_classes = [c for c in dir(modeling_overflow) if 'ForCausalLM' in c]
44
  if model_classes:
45
  model_class = getattr(modeling_overflow, model_classes[0])
46
  AutoModelForCausalLM.register(conf_class, model_class)
47
  print(f"Successfully registered {model_classes[0]} to AutoModel.")
48
 
49
- # 3. Now load them into the global variables
50
  print("Loading Tokenizer...")
51
  tokenizer = AutoTokenizer.from_pretrained(
52
  MODEL_PATH,
53
  trust_remote_code=True
54
  )
55
 
 
 
56
  print("Loading Model Weights (111.7B Parameters - 1-Bit)...")
57
  model = AutoModelForCausalLM.from_pretrained(
58
  MODEL_PATH,
@@ -66,31 +71,34 @@ try:
66
  except Exception as e:
67
  print(f"CRITICAL LOADING ERROR: {e}")
68
 
69
- # --- SCHEMAS ---
70
  class Query(BaseModel):
71
  prompt: str
72
  max_tokens: int = 50
73
  temperature: float = 0.7
74
 
75
- # --- AUTH ---
76
  @app.get("/api/generate")
77
  async def create_new_key():
 
78
  new_key = f"of_sk-{secrets.token_hex(12)}"
79
  generated_keys[new_key] = {"created_at": time.time()}
80
  return {"status": "success", "api_key": new_key}
81
 
82
  async def verify_auth(api_key: str = Depends(api_key_header)):
 
83
  if api_key in generated_keys or api_key == os.environ.get("MASTER_API_KEY"):
84
  return api_key
85
  raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key")
86
 
87
- # --- ENDPOINTS ---
88
  @app.post("/v1/generate")
89
  async def generate(query: Query, auth: str = Depends(verify_auth)):
 
90
  if tokenizer is None or model is None:
91
  raise HTTPException(
92
  status_code=HTTP_503_SERVICE_UNAVAILABLE,
93
- detail="Engine is still booting up. Please wait."
94
  )
95
 
96
  try:
@@ -103,11 +111,18 @@ async def generate(query: Query, auth: str = Depends(verify_auth)):
103
  do_sample=True if query.temperature > 0 else False
104
  )
105
  response_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
106
- return {"model": "Overflow-111.7B", "choices": [{"text": response_text}]}
 
 
 
107
  except Exception as e:
108
  raise HTTPException(status_code=500, detail=str(e))
109
 
110
  @app.get("/")
111
  def health():
112
  state = "active" if model else "loading"
113
- return {"status": state, "engine": "Overflow-111.7B"}
 
 
 
 
 
10
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
11
  from starlette.status import HTTP_403_FORBIDDEN, HTTP_503_SERVICE_UNAVAILABLE
12
 
13
+ # --- 1. GLOBAL INITIALIZATION ---
14
+ # We define these at the top level so they exist when the app starts.
 
 
 
 
15
  tokenizer = None
16
  model = None
17
  generated_keys = {}
18
 
19
+ # --- 2. CONFIGURATION ---
20
+ MODEL_PATH = "/app/model"
21
+ API_KEY_NAME = "X-API-Key"
22
+ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
23
+
24
  app = FastAPI(title="Overflow-111.7B Self-Registering API")
25
 
26
+ # --- 3. ENGINE LOADING & SELF-REGISTRATION ---
27
  print("Starting Engine: Initializing Self-Registration...")
28
 
29
  try:
30
+ # IMPORTANT: Global declaration must come BEFORE any usage in this block
31
  global tokenizer, model
32
 
33
+ # Add model path to system so Python finds configuration_overflow.py
34
  if MODEL_PATH not in sys.path:
35
  sys.path.insert(0, MODEL_PATH)
36
 
37
+ # Force-Register the Custom Configuration
38
  import configuration_overflow
39
  conf_class = configuration_overflow.OverflowConfig
40
  AutoConfig.register("overflow", conf_class)
41
+ print("Successfully registered 'overflow' config.")
42
 
43
+ # Force-Register the Custom Model Architecture
44
  import modeling_overflow
45
+ # Dynamically find the CausalLM class (usually OverflowForCausalLM)
46
  model_classes = [c for c in dir(modeling_overflow) if 'ForCausalLM' in c]
47
  if model_classes:
48
  model_class = getattr(modeling_overflow, model_classes[0])
49
  AutoModelForCausalLM.register(conf_class, model_class)
50
  print(f"Successfully registered {model_classes[0]} to AutoModel.")
51
 
52
+ # Load Tokenizer
53
  print("Loading Tokenizer...")
54
  tokenizer = AutoTokenizer.from_pretrained(
55
  MODEL_PATH,
56
  trust_remote_code=True
57
  )
58
 
59
+ # Load Model Weights
60
+ # Optimized for CPU usage with bfloat16 and low memory footprint
61
  print("Loading Model Weights (111.7B Parameters - 1-Bit)...")
62
  model = AutoModelForCausalLM.from_pretrained(
63
  MODEL_PATH,
 
71
  except Exception as e:
72
  print(f"CRITICAL LOADING ERROR: {e}")
73
 
74
+ # --- 4. API SCHEMAS ---
75
  class Query(BaseModel):
76
  prompt: str
77
  max_tokens: int = 50
78
  temperature: float = 0.7
79
 
80
+ # --- 5. AUTHENTICATION ---
81
  @app.get("/api/generate")
82
  async def create_new_key():
83
+ """Generates a unique of_sk- key for the session."""
84
  new_key = f"of_sk-{secrets.token_hex(12)}"
85
  generated_keys[new_key] = {"created_at": time.time()}
86
  return {"status": "success", "api_key": new_key}
87
 
88
  async def verify_auth(api_key: str = Depends(api_key_header)):
89
+ """Validates the X-API-Key header."""
90
  if api_key in generated_keys or api_key == os.environ.get("MASTER_API_KEY"):
91
  return api_key
92
  raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key")
93
 
94
+ # --- 6. CORE ENDPOINTS ---
95
  @app.post("/v1/generate")
96
  async def generate(query: Query, auth: str = Depends(verify_auth)):
97
+ # If a user pings the API before the 111.7B weights are in RAM
98
  if tokenizer is None or model is None:
99
  raise HTTPException(
100
  status_code=HTTP_503_SERVICE_UNAVAILABLE,
101
+ detail="Engine is still booting up (111.7B parameters take time). Please wait."
102
  )
103
 
104
  try:
 
111
  do_sample=True if query.temperature > 0 else False
112
  )
113
  response_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
114
+ return {
115
+ "model": "Overflow-111.7B",
116
+ "choices": [{"text": response_text}]
117
+ }
118
  except Exception as e:
119
  raise HTTPException(status_code=500, detail=str(e))
120
 
121
  @app.get("/")
122
  def health():
123
  state = "active" if model else "loading"
124
+ return {"status": state, "engine": "Overflow-111.7B"}
125
+
126
+ if __name__ == "__main__":
127
+ import uvicorn
128
+ uvicorn.run(app, host="0.0.0.0", port=7860)