import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
from peft import PeftModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import os

# --- Global Variables for Model and Tokenizer ---
model = None
tokenizer = None
model_loaded_successfully = False # Flag to indicate model status
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--- Initializing on Device: {device} ---")

# --- Pydantic Model for Request Body ---
class PromptRequest(BaseModel):
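    """Request schema for /generate/; sampling defaults are common Hugging Face generation values."""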
    prompt: str
    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50

# --- FastAPI App Initialization ---
app = FastAPI()

def load_model_and_tokenizer():
    global model, tokenizer, model_loaded_successfully

    base_model_id = os.environ.get("BASE_MODEL_ID")
    adapter_path = os.environ.get("ADAPTER_PATH")
    hf_token = os.environ.get("HF_TOKEN") 

    if not base_model_id:
        print("CRITICAL ERROR: BASE_MODEL_ID environment variable not set.")
        # In a real app, you might want to prevent startup or handle this more gracefully
        return 
    if not adapter_path:
        print("CRITICAL ERROR: ADAPTER_PATH environment variable not set.")
        return

    print(f"Using device: {device}")
    print(f"Attempting to load base model: {base_model_id}")
    print(f"Attempting to load adapter from: {adapter_path}")

    try:
        # --- Load Tokenizer ---
        print("Loading tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(adapter_path, token=hf_token, trust_remote_code=True)
            print(f"Loaded tokenizer from adapter path: {adapter_path}")
        except Exception as e:
            print(f"Could not load tokenizer from adapter path: {e}. Loading from base model path: {base_model_id}")
            tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)

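        # Decoder-only checkpoints often ship without a pad token; fall back to EOS or add one,
        # and left-pad so batched causal-LM generation keeps prompts aligned to the right.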
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                print("Setting pad_token to eos_token.")
                tokenizer.pad_token = tokenizer.eos_token
            else:
                print("Adding new pad_token '[PAD]'.")
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        tokenizer.padding_side = "left"

        # --- Configure Quantization ---
        print("Configuring 4-bit quantization...")
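        # NF4 stores weights in 4 bits (double quantization also compresses the quantization
        # constants), while matmuls run in the bf16/fp16 compute dtype chosen below.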
        compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() and device == "cuda" else torch.float16
        
        bnb_config = None
        if device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
            )
            print(f"Using BNB config with compute_dtype: {compute_dtype}")
        else:
            print("Running on CPU, BNB quantization will not be applied.")

        # --- Load Base Model with Quantization ---
        print(f"Loading base model: {base_model_id}...")
        config = AutoConfig.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
        if getattr(config, "pretraining_tp", 1) != 1:
            print(f"Overriding pretraining_tp from {getattr(config, 'pretraining_tp', 'N/A')} to 1.")
            config.pretraining_tp = 1
        
        base_model_instance = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            config=config,
            quantization_config=bnb_config,  # already None when running on CPU
            device_map={"": device},
            token=hf_token,
            trust_remote_code=True,
            low_cpu_mem_usage=(device == "cuda")
        )
        print("Base model loaded.")

        if tokenizer.pad_token_id is not None and tokenizer.pad_token_id >= base_model_instance.config.vocab_size:
            print("Resizing token embeddings for base model.")
            base_model_instance.resize_token_embeddings(len(tokenizer))

        # --- Load LoRA Adapter ---
        print(f"Loading LoRA adapter from: {adapter_path}...")
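        # PeftModel.from_pretrained attaches the LoRA adapter on top of the (quantized) base model;
        # the low-rank deltas are applied at inference time rather than merged into the base weights.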
        model = PeftModel.from_pretrained(base_model_instance, adapter_path)
        model.eval()
        print("LoRA adapter loaded and model is in eval mode.")
        print(f"Model is on device: {model.device}")
        model_loaded_successfully = True # Set flag on successful load
        print("Model and tokenizer loaded successfully.")

    except Exception as e:
        print(f"CRITICAL ERROR during model/tokenizer loading: {e}")
        model_loaded_successfully = False
        # Optionally re-raise here to prevent the app from starting when model loading fails.
        # For now the error is just printed; the /generate endpoint and the health check
        # will report that the model is not loaded.

@app.on_event("startup")
async def startup_event():
    print("Server startup event: Initiating model and tokenizer loading...")
    # Model loading can take time, so it's done here.
    # Health checks might hit the server before this completes.
    load_model_and_tokenizer()
    if model_loaded_successfully:
        print("Model loading process completed successfully within startup event.")
    else:
        print("Model loading process encountered an error or did not complete within startup event.")


# <<< --- ADDED HEALTH CHECK ENDPOINT --- >>>
@app.get("/")
async def health_check():
    """Basic health check endpoint."""
    if model_loaded_successfully and model is not None and tokenizer is not None:
        return {"status": "ok", "message": "Model is loaded and ready."}
    else:
        # Return 503 while the model is still loading (or if loading failed) so Spaces
        # knows the app is not ready to serve requests yet.
        raise HTTPException(status_code=503, detail="Model is not loaded or still loading.")

@app.get("/health") # Common alternative health check path
async def health_check_alternative():
    return await health_check()
# <<< --- END OF HEALTH CHECK ENDPOINT --- >>>


@app.post("/generate/")
async def generate_text(request: PromptRequest):
    global model, tokenizer, model_loaded_successfully
    if not model_loaded_successfully or model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or still loading. Please try again shortly or check server logs.")

    try:
        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        
        print(f"Received prompt: {request.prompt}")
        print("Generating...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_new_tokens,
                num_return_sequences=1,
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        
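        # Slice off the prompt tokens so only the newly generated continuation is decoded.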
        prompt_tokens = inputs.input_ids.shape[-1]
        if outputs[0].size(0) > prompt_tokens:
            generated_sequence = outputs[0][prompt_tokens:]
            generated_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)
        else: 
            generated_text = "" 
            
        print(f"Generated text: {generated_text}")
        return {"generated_text": generated_text}
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    print("Starting Uvicorn server directly from app.py for local testing...")
    port = int(os.environ.get("PORT", 8000))
    host = "0.0.0.0"
    print(f"Uvicorn will attempt to listen on host {host}, port {port}")
    print("Set BASE_MODEL_ID and ADAPTER_PATH environment variables for model loading.")
    
    # The @app.on_event("startup") will be called by Uvicorn.
    try:
        uvicorn.run(app, host=host, port=port)
    except Exception as e:
        print(f"Error attempting to run uvicorn: {e}")
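
# Example request once the server is running (illustrative; assumes the default host/port above):
#   curl -X POST http://localhost:8000/generate/ \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Tell me about LoRA adapters.", "max_new_tokens": 64}'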