import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
from peft import PeftModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import os
import time  # For checking model load status

# --- Global Variables for Model and Tokenizer ---
model = None
tokenizer = None
model_loaded_successfully = False  # Flag to indicate model status
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--- Initializing on Device: {device} ---")

# --- Pydantic Model for Request Body ---
class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50
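
# A JSON request body matching this schema would look like (illustrative values only):
# {"prompt": "Hello", "max_new_tokens": 64, "temperature": 0.7, "top_p": 0.9, "top_k": 50}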

# --- FastAPI App Initialization ---
app = FastAPI()
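
# --- Expected environment variables (illustrative values; substitute your own repos) ---
# BASE_MODEL_ID=meta-llama/Llama-2-7b-hf        # base model repo on the Hugging Face Hub
# ADAPTER_PATH=your-username/your-lora-adapter  # LoRA adapter repo id or local directory
# HF_TOKEN=hf_...                               # only needed for gated or private repos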

def load_model_and_tokenizer():
    global model, tokenizer, model_loaded_successfully

    base_model_id = os.environ.get("BASE_MODEL_ID")
    adapter_path = os.environ.get("ADAPTER_PATH")
    hf_token = os.environ.get("HF_TOKEN")

    if not base_model_id:
        print("CRITICAL ERROR: BASE_MODEL_ID environment variable not set.")
        # In a real app, you might want to prevent startup or handle this more gracefully.
        return
    if not adapter_path:
        print("CRITICAL ERROR: ADAPTER_PATH environment variable not set.")
        return

    print(f"Using device: {device}")
    print(f"Attempting to load base model: {base_model_id}")
    print(f"Attempting to load adapter from: {adapter_path}")

    try:
        # --- Load Tokenizer ---
        print("Loading tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(adapter_path, token=hf_token, trust_remote_code=True)
            print(f"Loaded tokenizer from adapter path: {adapter_path}")
        except Exception as e:
            print(f"Could not load tokenizer from adapter path: {e}. Loading from base model path: {base_model_id}")
            tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)

        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                print("Setting pad_token to eos_token.")
                tokenizer.pad_token = tokenizer.eos_token
            else:
                print("Adding new pad_token '[PAD]'.")
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        tokenizer.padding_side = "left"

        # --- Configure Quantization ---
        print("Configuring 4-bit quantization...")
        compute_dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float16
        bnb_config = None
        if device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
            )
            print(f"Using BNB config with compute_dtype: {compute_dtype}")
        else:
            print("Running on CPU, BNB quantization will not be applied.")

        # --- Load Base Model with Quantization ---
        print(f"Loading base model: {base_model_id}...")
        config = AutoConfig.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
        if getattr(config, "pretraining_tp", 1) != 1:
            print(f"Overriding pretraining_tp from {getattr(config, 'pretraining_tp', 'N/A')} to 1.")
            config.pretraining_tp = 1

        base_model_instance = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            config=config,
            quantization_config=bnb_config if device == "cuda" else None,
            device_map={"": device},
            token=hf_token,
            trust_remote_code=True,
            low_cpu_mem_usage=True if device == "cuda" else False,
        )
        print("Base model loaded.")

        if tokenizer.pad_token_id is not None and tokenizer.pad_token_id >= base_model_instance.config.vocab_size:
            print("Resizing token embeddings for base model.")
            base_model_instance.resize_token_embeddings(len(tokenizer))

        # --- Load LoRA Adapter ---
        print(f"Loading LoRA adapter from: {adapter_path}...")
        model = PeftModel.from_pretrained(base_model_instance, adapter_path)
        model.eval()
        print("LoRA adapter loaded and model is in eval mode.")
        print(f"Model is on device: {model.device}")

        model_loaded_successfully = True  # Set flag on successful load
        print("Model and tokenizer loaded successfully.")
    except Exception as e:
        print(f"CRITICAL ERROR during model/tokenizer loading: {e}")
        model_loaded_successfully = False
        # Optionally, re-raise here to prevent the app from starting when loading fails.
        # For now, the error is printed; the /generate endpoint will report that the
        # model is not loaded, and the health check will show it is not ready.

@app.on_event("startup")
async def startup_event():
    print("Server startup event: Initiating model and tokenizer loading...")
    # Model loading can take time, so it's done here.
    # Health checks might hit the server before this completes.
    load_model_and_tokenizer()
    if model_loaded_successfully:
        print("Model loading process completed successfully within startup event.")
    else:
        print("Model loading process encountered an error or did not complete within startup event.")

# <<< --- ADDED HEALTH CHECK ENDPOINT --- >>>
@app.get("/health")
async def health_check():
    """Basic health check endpoint."""
    if model_loaded_successfully and model is not None and tokenizer is not None:
        return {"status": "ok", "message": "Model is loaded and ready."}
    else:
        # Return a 503 if the model isn't ready yet, so Spaces knows it's still starting up
        # or that loading failed.
        raise HTTPException(status_code=503, detail="Model is not loaded or still loading.")

# Common alternative health check path
@app.get("/healthz")
async def health_check_alternative():
    return await health_check()
# <<< --- END OF HEALTH CHECK ENDPOINT --- >>>
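
# Readiness can be polled before sending generation requests (a sketch; assumes a local
# server and the /health route above):
# import requests
# requests.get("http://localhost:8000/health").status_code  # 200 once loaded, 503 before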

@app.post("/generate")
async def generate_text(request: PromptRequest):
    global model, tokenizer, model_loaded_successfully
    if not model_loaded_successfully or model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or still loading. Please try again shortly or check server logs.")
    try:
        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        print(f"Received prompt: {request.prompt}")
        print("Generating...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_new_tokens,
                num_return_sequences=1,
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_tokens = inputs.input_ids.shape[-1]
        if outputs[0].size(0) > prompt_tokens:
            generated_sequence = outputs[0][prompt_tokens:]
            generated_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)
        else:
            generated_text = ""
        print(f"Generated text: {generated_text}")
        return {"generated_text": generated_text}
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    print("Starting Uvicorn server directly from app.py for local testing...")
    port = int(os.environ.get("PORT", 8000))
    host = "0.0.0.0"
    print(f"Uvicorn will attempt to listen on host {host}, port {port}")
    print("Set BASE_MODEL_ID and ADAPTER_PATH environment variables for model loading.")
    # The @app.on_event("startup") handler will be called by Uvicorn.
    try:
        uvicorn.run(app, host=host, port=port)
    except Exception as e:
        print(f"Error attempting to run uvicorn: {e}")