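# NOTE: the three lines below are non-code page-header text captured with this file;
# they are kept as comments so the module remains valid Python.
# danmac1's picture
# Update app.py
# 3bd8666 verified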
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
from peft import PeftModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import os
import time # For checking model load status
# --- Global Variables for Model and Tokenizer ---
# Both handles stay None until load_model_and_tokenizer() assigns them.
model = tokenizer = None
# Set to True only at the end of a successful load_model_and_tokenizer() run.
model_loaded_successfully = False
# Pick the GPU whenever torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"--- Initializing on Device: {device} ---")
# --- Pydantic Model for Request Body ---
class PromptRequest(BaseModel):
    """JSON body accepted by the text-generation endpoint.

    Only ``prompt`` is required; every other field is a generation setting
    that falls back to the default shown here when the client omits it.
    """

    # Text the model is asked to continue.
    prompt: str
    # Forwarded to ``model.generate`` as ``max_new_tokens``.
    max_new_tokens: int = 256
    # Forwarded to ``model.generate`` as the sampling ``temperature``.
    temperature: float = 0.7
    # Forwarded to ``model.generate`` as the nucleus-sampling ``top_p``.
    top_p: float = 0.9
    # Forwarded to ``model.generate`` as the ``top_k`` cutoff.
    top_k: int = 50
# --- FastAPI App Initialization ---
# Single application object; the startup hook and every route in this file
# register themselves on it through the @app.* decorators.
app = FastAPI()
def _load_tokenizer(adapter_path, base_model_id, hf_token):
    """Load the tokenizer (adapter copy first, base model as fallback) and make it paddable."""
    print("Loading tokenizer...")
    try:
        loaded = AutoTokenizer.from_pretrained(adapter_path, token=hf_token, trust_remote_code=True)
        print(f"Loaded tokenizer from adapter path: {adapter_path}")
    except Exception as e:
        # Deliberate best-effort: any failure on the adapter path falls back
        # to the base model's tokenizer instead of aborting the load.
        print(f"Could not load tokenizer from adapter path: {e}. Loading from base model path: {base_model_id}")
        loaded = AutoTokenizer.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)

    if loaded.pad_token is None:
        if loaded.eos_token is not None:
            print("Setting pad_token to eos_token.")
            loaded.pad_token = loaded.eos_token
        else:
            print("Adding new pad_token '[PAD]'.")
            loaded.add_special_tokens({'pad_token': '[PAD]'})
    loaded.padding_side = "left"
    return loaded


def _build_quantization_config():
    """Return the 4-bit NF4 BitsAndBytes config on CUDA, or None when running on CPU."""
    print("Configuring 4-bit quantization...")
    if device != "cuda":
        print("Running on CPU, BNB quantization will not be applied.")
        return None
    # CUDA capabilities are only queried once CUDA is known to be the
    # selected device; the original asked before checking the device.
    compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )
    print(f"Using BNB config with compute_dtype: {compute_dtype}")
    return bnb_config


def load_model_and_tokenizer():
    """Load the tokenizer, base model and LoRA adapter into the module globals.

    Reads three environment variables:

    * ``BASE_MODEL_ID`` (required): identifier of the base causal LM.
    * ``ADAPTER_PATH`` (required): location of the LoRA adapter; also tried
      first as the tokenizer source.
    * ``HF_TOKEN`` (optional): access token forwarded to every download.

    On success the globals ``model`` and ``tokenizer`` are populated and
    ``model_loaded_successfully`` is set to True. If a required variable is
    missing the function prints an error and returns without touching the
    flag. Any exception raised while loading is caught, printed together with
    its traceback, and leaves ``model_loaded_successfully`` False; nothing is
    re-raised, so callers must inspect the flag.

    NOTE(review): ``trust_remote_code=True`` lets the referenced repositories
    execute their own Python code; only point the variables at trusted sources.
    """
    import traceback

    global model, tokenizer, model_loaded_successfully

    base_model_id = os.environ.get("BASE_MODEL_ID")
    adapter_path = os.environ.get("ADAPTER_PATH")
    hf_token = os.environ.get("HF_TOKEN")

    if not base_model_id:
        print("CRITICAL ERROR: BASE_MODEL_ID environment variable not set.")
        return
    if not adapter_path:
        print("CRITICAL ERROR: ADAPTER_PATH environment variable not set.")
        return

    print(f"Using device: {device}")
    print(f"Attempting to load base model: {base_model_id}")
    print(f"Attempting to load adapter from: {adapter_path}")

    try:
        # --- Load Tokenizer ---
        tokenizer = _load_tokenizer(adapter_path, base_model_id, hf_token)

        # --- Configure Quantization ---
        bnb_config = _build_quantization_config()

        # --- Load Base Model with Quantization ---
        print(f"Loading base model: {base_model_id}...")
        config = AutoConfig.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
        if getattr(config, "pretraining_tp", 1) != 1:
            print(f"Overriding pretraining_tp from {config.pretraining_tp} to 1.")
            config.pretraining_tp = 1

        base_model_instance = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            config=config,
            quantization_config=bnb_config,  # already None on CPU
            device_map={"": device},
            token=hf_token,
            trust_remote_code=True,
            low_cpu_mem_usage=(device == "cuda"),
        )
        print("Base model loaded.")

        # Resize whenever the tokenizer holds more entries than the model's
        # vocabulary. This covers a newly added pad token (the only case the
        # original checked) as well as any other tokens added to the
        # adapter's tokenizer.
        if len(tokenizer) > base_model_instance.config.vocab_size:
            print("Resizing token embeddings for base model.")
            base_model_instance.resize_token_embeddings(len(tokenizer))

        # --- Load LoRA Adapter ---
        print(f"Loading LoRA adapter from: {adapter_path}...")
        # Forward the access token so a private adapter repository can be
        # fetched with the same credentials as the tokenizer and base model.
        # The keyword is omitted entirely when no token is configured.
        adapter_kwargs = {"token": hf_token} if hf_token else {}
        model = PeftModel.from_pretrained(base_model_instance, adapter_path, **adapter_kwargs)
        model.eval()
        print("LoRA adapter loaded and model is in eval mode.")
        print(f"Model is on device: {model.device}")

        model_loaded_successfully = True  # Set flag on successful load
        print("Model and tokenizer loaded successfully.")
    except Exception as e:
        # Top-level boundary for the load: report and leave the flag False so
        # the health and generate endpoints answer 503 instead of crashing.
        print(f"CRITICAL ERROR during model/tokenizer loading: {e}")
        traceback.print_exc()
        model_loaded_successfully = False
@app.on_event("startup")
async def startup_event():
    """Load the model and tokenizer once, when the server starts.

    ``load_model_and_tokenizer`` is called synchronously, so this handler
    does not return until loading has finished or failed. A failed load is
    only reported on stdout; no exception is raised from here.
    """
    print("Server startup event: Initiating model and tokenizer loading...")
    load_model_and_tokenizer()
    outcome = (
        "Model loading process completed successfully within startup event."
        if model_loaded_successfully
        else "Model loading process encountered an error or did not complete within startup event."
    )
    print(outcome)
# <<< --- ADDED HEALTH CHECK ENDPOINT --- >>>
@app.get("/")
async def health_check():
    """Report readiness: 200 with a status payload once the model is loaded, 503 otherwise."""
    not_ready = not model_loaded_successfully or model is None or tokenizer is None
    if not_ready:
        # 503 covers both "still starting up" and "loading failed".
        raise HTTPException(status_code=503, detail="Model is not loaded or still loading.")
    return {"status": "ok", "message": "Model is loaded and ready."}
@app.get("/health")  # Common alternative health check path
async def health_check_alternative():
    """Serve the same readiness result as ``health_check`` under ``/health``."""
    readiness = await health_check()
    return readiness
# <<< --- END OF HEALTH CHECK ENDPOINT --- >>>
@app.post("/generate/")
async def generate_text(request: PromptRequest):
    """Generate a continuation for ``request.prompt`` with the loaded model.

    Args:
        request: Prompt plus generation settings. ``temperature == 0`` selects
            greedy (deterministic) decoding; any positive temperature samples
            using ``temperature``, ``top_p`` and ``top_k``.

    Returns:
        ``{"generated_text": <str>}`` holding only the newly generated text
        (the prompt tokens are stripped before decoding).

    Raises:
        HTTPException: 503 when the model or tokenizer is not loaded,
            422 when ``max_new_tokens`` is below 1 or ``temperature`` is
            negative, and 500 for any error raised during
            tokenization/generation.
    """
    if not model_loaded_successfully or model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or still loading. Please try again shortly or check server logs.")

    # Report unusable settings as client errors instead of letting them
    # surface from generate() as an opaque 500.
    if request.max_new_tokens < 1:
        raise HTTPException(status_code=422, detail="max_new_tokens must be at least 1.")
    if request.temperature < 0:
        raise HTTPException(status_code=422, detail="temperature must be non-negative.")

    if request.temperature > 0:
        sampling_kwargs = {
            "do_sample": True,
            "temperature": request.temperature,
            "top_p": request.top_p,
            "top_k": request.top_k,
        }
    else:
        # Sampling requires a strictly positive temperature; zero is served as
        # greedy decoding, which is the limit of sampling as temperature -> 0.
        sampling_kwargs = {"do_sample": False}

    try:
        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

        print(f"Received prompt: {request.prompt}")
        print("Generating...")

        # NOTE(review): generate() is a blocking call made directly inside this
        # async handler, so other requests (including health checks) appear to
        # wait while it runs -- confirm whether that is acceptable.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_new_tokens,
                num_return_sequences=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                **sampling_kwargs,
            )

        # The output sequence begins with the prompt tokens; keep only what
        # follows them.
        prompt_tokens = inputs.input_ids.shape[-1]
        if outputs[0].size(0) > prompt_tokens:
            generated_sequence = outputs[0][prompt_tokens:]
            generated_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)
        else:
            generated_text = ""

        print(f"Generated text: {generated_text}")
        return {"generated_text": generated_text}
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Local-testing entry point: serve the app with Uvicorn on all interfaces,
    # using $PORT when set and 8000 otherwise.
    print("Starting Uvicorn server directly from app.py for local testing...")
    host, port = "0.0.0.0", int(os.environ.get("PORT", 8000))
    print(f"Uvicorn will attempt to listen on host {host}, port {port}")
    print("Set BASE_MODEL_ID and ADAPTER_PATH environment variables for model loading.")
    # Uvicorn triggers the @app.on_event("startup") handler itself.
    try:
        uvicorn.run(app, host=host, port=port)
    except Exception as e:
        print(f"Error attempting to run uvicorn: {e}")