danmac1 committed on
Commit
3bd8666
·
verified ·
1 Parent(s): db4b4c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -84
app.py CHANGED
@@ -5,10 +5,12 @@ from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  import uvicorn
7
  import os
 
8
 
9
  # --- Global Variables for Model and Tokenizer ---
10
  model = None
11
  tokenizer = None
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  print(f"--- Initializing on Device: {device} ---")
14
 
@@ -24,104 +26,130 @@ class PromptRequest(BaseModel):
24
  app = FastAPI()
25
 
26
  def load_model_and_tokenizer():
27
- global model, tokenizer
28
 
29
  base_model_id = os.environ.get("BASE_MODEL_ID")
30
  adapter_path = os.environ.get("ADAPTER_PATH")
31
  hf_token = os.environ.get("HF_TOKEN")
32
 
33
  if not base_model_id:
34
- print("ERROR: BASE_MODEL_ID environment variable not set.")
35
- raise ValueError("BASE_MODEL_ID environment variable not set.")
 
36
  if not adapter_path:
37
- print("ERROR: ADAPTER_PATH environment variable not set.")
38
- raise ValueError("ADAPTER_PATH environment variable not set.")
39
 
40
  print(f"Using device: {device}")
41
  print(f"Attempting to load base model: {base_model_id}")
42
  print(f"Attempting to load adapter from: {adapter_path}")
43
 
44
- # --- Load Tokenizer ---
45
- print(f"Loading tokenizer...")
46
  try:
47
- tokenizer = AutoTokenizer.from_pretrained(adapter_path, token=hf_token, trust_remote_code=True)
48
- print(f"Loaded tokenizer from adapter path: {adapter_path}")
49
- except Exception as e:
50
- print(f"Could not load tokenizer from adapter path: {e}. Loading from base model path: {base_model_id}")
51
- tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
52
-
53
- if tokenizer.pad_token is None:
54
- if tokenizer.eos_token is not None:
55
- print("Setting pad_token to eos_token.")
56
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  else:
58
- print("Adding new pad_token '[PAD]'.")
59
- tokenizer.add_special_tokens({'pad_token': '[PAD]'})
60
- tokenizer.padding_side = "left"
61
-
62
- # --- Configure Quantization ---
63
- print("Configuring 4-bit quantization...")
64
- compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() and device == "cuda" else torch.float16
65
-
66
- bnb_config = None
67
- if device == "cuda":
68
- bnb_config = BitsAndBytesConfig(
69
- load_in_4bit=True,
70
- bnb_4bit_quant_type="nf4",
71
- bnb_4bit_compute_dtype=compute_dtype,
72
- bnb_4bit_use_double_quant=True,
 
 
73
  )
74
- print(f"Using BNB config with compute_dtype: {compute_dtype}")
75
- else:
76
- print("Running on CPU, BNB quantization will not be applied.")
77
-
78
- # --- Load Base Model with Quantization ---
79
- print(f"Loading base model: {base_model_id}...")
80
- config = AutoConfig.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
81
- if getattr(config, "pretraining_tp", 1) != 1:
82
- print(f"Overriding pretraining_tp from {getattr(config, 'pretraining_tp', 'N/A')} to 1.")
83
- config.pretraining_tp = 1
84
-
85
- base_model_instance = AutoModelForCausalLM.from_pretrained(
86
- base_model_id,
87
- config=config,
88
- quantization_config=bnb_config if device == "cuda" else None,
89
- device_map={"": device},
90
- token=hf_token,
91
- trust_remote_code=True,
92
- low_cpu_mem_usage=True if device == "cuda" else False
93
- )
94
- print("Base model loaded.")
95
-
96
- if tokenizer.pad_token_id is not None and tokenizer.pad_token_id >= base_model_instance.config.vocab_size:
97
- print("Resizing token embeddings for base model.")
98
- base_model_instance.resize_token_embeddings(len(tokenizer))
99
-
100
- # --- Load LoRA Adapter ---
101
- print(f"Loading LoRA adapter from: {adapter_path}...")
102
- model = PeftModel.from_pretrained(base_model_instance, adapter_path)
103
- model.eval()
104
- print("LoRA adapter loaded and model is in eval mode.")
105
- print(f"Model is on device: {model.device}")
106
 
107
  @app.on_event("startup")
108
  async def startup_event():
109
- print("Server startup event: Loading model and tokenizer...")
110
- try:
111
- load_model_and_tokenizer()
112
- print("Model and tokenizer loaded successfully via startup event.")
113
- except Exception as e:
114
- print(f"CRITICAL ERROR during startup model loading: {e}")
115
- # This error might not stop Uvicorn if it's already started by __main__
116
- # but it will prevent the /generate endpoint from working.
117
- # Consider raising an exception here to potentially stop the app if model load fails.
118
- # For now, it will print and the /generate endpoint will show model not loaded.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  @app.post("/generate/")
121
  async def generate_text(request: PromptRequest):
122
- global model, tokenizer
123
- if model is None or tokenizer is None:
124
- # This error will be returned to the client
125
  raise HTTPException(status_code=503, detail="Model is not loaded or still loading. Please try again shortly or check server logs.")
126
 
127
  try:
@@ -156,16 +184,13 @@ async def generate_text(request: PromptRequest):
156
  raise HTTPException(status_code=500, detail=str(e))
157
 
158
  if __name__ == "__main__":
159
- print("Starting Uvicorn server directly from app.py...")
160
- # Hugging Face Spaces injects the PORT environment variable.
161
- # Default to 8000 if not set (for local testing without Spaces).
162
  port = int(os.environ.get("PORT", 8000))
163
- host = "0.0.0.0" # Listen on all available network interfaces
164
-
165
  print(f"Uvicorn will attempt to listen on host {host}, port {port}")
 
166
 
167
- # The @app.on_event("startup") should be called by Uvicorn when it starts the app.
168
- # This will trigger load_model_and_tokenizer().
169
  try:
170
  uvicorn.run(app, host=host, port=port)
171
  except Exception as e:
 
5
  from pydantic import BaseModel
6
  import uvicorn
7
  import os
8
+ import time # For checking model load status
9
 
10
  # --- Global Variables for Model and Tokenizer ---
11
  model = None
12
  tokenizer = None
13
+ model_loaded_successfully = False # Flag to indicate model status
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  print(f"--- Initializing on Device: {device} ---")
16
 
 
26
  app = FastAPI()
27
 
28
  def load_model_and_tokenizer():
29
+ global model, tokenizer, model_loaded_successfully
30
 
31
  base_model_id = os.environ.get("BASE_MODEL_ID")
32
  adapter_path = os.environ.get("ADAPTER_PATH")
33
  hf_token = os.environ.get("HF_TOKEN")
34
 
35
  if not base_model_id:
36
+ print("CRITICAL ERROR: BASE_MODEL_ID environment variable not set.")
37
+ # In a real app, you might want to prevent startup or handle this more gracefully
38
+ return
39
  if not adapter_path:
40
+ print("CRITICAL ERROR: ADAPTER_PATH environment variable not set.")
41
+ return
42
 
43
  print(f"Using device: {device}")
44
  print(f"Attempting to load base model: {base_model_id}")
45
  print(f"Attempting to load adapter from: {adapter_path}")
46
 
 
 
47
  try:
48
+ # --- Load Tokenizer ---
49
+ print(f"Loading tokenizer...")
50
+ try:
51
+ tokenizer = AutoTokenizer.from_pretrained(adapter_path, token=hf_token, trust_remote_code=True)
52
+ print(f"Loaded tokenizer from adapter path: {adapter_path}")
53
+ except Exception as e:
54
+ print(f"Could not load tokenizer from adapter path: {e}. Loading from base model path: {base_model_id}")
55
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
56
+
57
+ if tokenizer.pad_token is None:
58
+ if tokenizer.eos_token is not None:
59
+ print("Setting pad_token to eos_token.")
60
+ tokenizer.pad_token = tokenizer.eos_token
61
+ else:
62
+ print("Adding new pad_token '[PAD]'.")
63
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
64
+ tokenizer.padding_side = "left"
65
+
66
+ # --- Configure Quantization ---
67
+ print("Configuring 4-bit quantization...")
68
+ compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() and device == "cuda" else torch.float16
69
+
70
+ bnb_config = None
71
+ if device == "cuda":
72
+ bnb_config = BitsAndBytesConfig(
73
+ load_in_4bit=True,
74
+ bnb_4bit_quant_type="nf4",
75
+ bnb_4bit_compute_dtype=compute_dtype,
76
+ bnb_4bit_use_double_quant=True,
77
+ )
78
+ print(f"Using BNB config with compute_dtype: {compute_dtype}")
79
  else:
80
+ print("Running on CPU, BNB quantization will not be applied.")
81
+
82
+ # --- Load Base Model with Quantization ---
83
+ print(f"Loading base model: {base_model_id}...")
84
+ config = AutoConfig.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
85
+ if getattr(config, "pretraining_tp", 1) != 1:
86
+ print(f"Overriding pretraining_tp from {getattr(config, 'pretraining_tp', 'N/A')} to 1.")
87
+ config.pretraining_tp = 1
88
+
89
+ base_model_instance = AutoModelForCausalLM.from_pretrained(
90
+ base_model_id,
91
+ config=config,
92
+ quantization_config=bnb_config if device == "cuda" else None,
93
+ device_map={"": device},
94
+ token=hf_token,
95
+ trust_remote_code=True,
96
+ low_cpu_mem_usage=True if device == "cuda" else False
97
  )
98
+ print("Base model loaded.")
99
+
100
+ if tokenizer.pad_token_id is not None and tokenizer.pad_token_id >= base_model_instance.config.vocab_size:
101
+ print("Resizing token embeddings for base model.")
102
+ base_model_instance.resize_token_embeddings(len(tokenizer))
103
+
104
+ # --- Load LoRA Adapter ---
105
+ print(f"Loading LoRA adapter from: {adapter_path}...")
106
+ model = PeftModel.from_pretrained(base_model_instance, adapter_path)
107
+ model.eval()
108
+ print("LoRA adapter loaded and model is in eval mode.")
109
+ print(f"Model is on device: {model.device}")
110
+ model_loaded_successfully = True # Set flag on successful load
111
+ print("Model and tokenizer loaded successfully.")
112
+
113
+ except Exception as e:
114
+ print(f"CRITICAL ERROR during model/tokenizer loading: {e}")
115
+ model_loaded_successfully = False
116
+ # Optionally, re-raise or handle to prevent app from starting if model load fails.
117
+ # For now, it will print error and the /generate endpoint will show model not loaded.
118
+ # And the health check will show model not ready.
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  @app.on_event("startup")
121
  async def startup_event():
122
+ print("Server startup event: Initiating model and tokenizer loading...")
123
+ # Model loading can take time, so it's done here.
124
+ # Health checks might hit the server before this completes.
125
+ load_model_and_tokenizer()
126
+ if model_loaded_successfully:
127
+ print("Model loading process completed successfully within startup event.")
128
+ else:
129
+ print("Model loading process encountered an error or did not complete within startup event.")
130
+
131
+
132
+ # <<< --- ADDED HEALTH CHECK ENDPOINT --- >>>
133
+ @app.get("/")
134
+ async def health_check():
135
+ """Basic health check endpoint."""
136
+ if model_loaded_successfully and model is not None and tokenizer is not None:
137
+ return {"status": "ok", "message": "Model is loaded and ready."}
138
+ else:
139
+ # Return a 503 if model isn't ready yet, so Spaces knows it's still starting up
140
+ # or if loading failed.
141
+ raise HTTPException(status_code=503, detail="Model is not loaded or still loading.")
142
+
143
+ @app.get("/health") # Common alternative health check path
144
+ async def health_check_alternative():
145
+ return await health_check()
146
+ # <<< --- END OF HEALTH CHECK ENDPOINT --- >>>
147
+
148
 
149
  @app.post("/generate/")
150
  async def generate_text(request: PromptRequest):
151
+ global model, tokenizer, model_loaded_successfully
152
+ if not model_loaded_successfully or model is None or tokenizer is None:
 
153
  raise HTTPException(status_code=503, detail="Model is not loaded or still loading. Please try again shortly or check server logs.")
154
 
155
  try:
 
184
  raise HTTPException(status_code=500, detail=str(e))
185
 
186
  if __name__ == "__main__":
187
+ print("Starting Uvicorn server directly from app.py for local testing...")
 
 
188
  port = int(os.environ.get("PORT", 8000))
189
+ host = "0.0.0.0"
 
190
  print(f"Uvicorn will attempt to listen on host {host}, port {port}")
191
+ print("Set BASE_MODEL_ID and ADAPTER_PATH environment variables for model loading.")
192
 
193
+ # The @app.on_event("startup") will be called by Uvicorn.
 
194
  try:
195
  uvicorn.run(app, host=host, port=port)
196
  except Exception as e: