from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import re
import time
from contextlib import asynccontextmanager

# --- Performance Optimizations & Model Loading ---

# 1. Device Selection: Use CUDA GPU if available for a massive speed boost.
device = "cuda" if torch.cuda.is_available() else "cpu"
# 2. Data Type: Use float16 on GPU for faster computation and less memory usage.
torch_dtype = torch.float16 if device == "cuda" else torch.float32

print(f"--- System Info ---")
print(f"Using device: {device}")
print(f"Using dtype: {torch_dtype}")
print("--------------------")

# --- App State and Model Placeholders ---
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = None
model = None

# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handles startup and shutdown events.
    Loads the ML model and tokenizer on startup.
    """
    global tokenizer, model
    
    print("Loading model and tokenizer...")
    start_time = time.time()
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Set pad token if it's not already set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    try:
        # 3. Attention Mechanism: Use Flash Attention 2 for a ~2x speedup on compatible GPUs.
        print("Attempting to load model with Flash Attention 2...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2"
        ).to(device)
        print("Successfully loaded model with Flash Attention 2.")
    except (ImportError, ValueError, RuntimeError) as e:
        # transformers raises ImportError when flash_attn is not installed and
        # ValueError when the model/device/dtype combination does not support it.
        print(f"Flash Attention 2 not available ({e}), falling back to default attention.")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
        ).to(device)

    # 4. Model Compilation (PyTorch 2.0+): JIT-compiles the model for faster execution.
    # Note: torch.compile is lazy, so most of the compilation cost is paid on the
    # first forward pass; expect the first /generate request to be slow.
    print("Compiling model with torch.compile()...")
    try:
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")
    except Exception as e:
        print(f"torch.compile() failed: {e}. Running with uncompiled model.")

    end_time = time.time()
    print(f"Model loading and compilation finished in {end_time - start_time:.2f} seconds.")
    
    yield
    
    # Clean up resources on shutdown (optional)
    print("Cleaning up and shutting down.")
    model = None
    tokenizer = None


# --- FastAPI App Initialization ---
app = FastAPI(lifespan=lifespan)

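# Note (per the FastAPI docs): wildcard origins only cover non-credentialed
# requests; browsers reject cookies/Authorization headers when Allow-Origin is
# '*', so restrict allow_origins if allow_credentials is actually needed.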
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)


# --- API Request and Response Models ---
class GenerationRequest(BaseModel):
    llm_commands: list[str]
    batch_size: int = 50

class GenerationResponse(BaseModel):
    data: list
    raw_output: str  # raw model text, returned to aid debugging
    duration_s: float  # wall-clock generation time in seconds
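
# Illustrative payloads for /generate (values are made up, not from a real run):
#   request:  {"llm_commands": ["a city name", "its country"], "batch_size": 2}
#   response: {"data": [["Paris", "France"], ["Lima", "Peru"]],
#              "raw_output": "...", "duration_s": 1.42}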


# --- Helper Functions ---
def extract_json_from_text(text: str):
    """
    Extracts a JSON array from the model's raw text output.
    This version is more robust and handles incomplete JSON at the end.
    """
    # Find the first '[' and the last ']' to bound the JSON content
    start_bracket = text.find('[')
    end_bracket = text.rfind(']')

    if start_bracket == -1 or end_bracket == -1:
        return None # No JSON array found

    json_str = text[start_bracket : end_bracket + 1]
    
    try:
        # Attempt to parse the primary JSON string
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Fallback for malformed JSON: try to recover complete rows individually.
        print("Warning: Initial JSON parsing failed. Attempting to recover partial data.")
        # Split on row boundaries, tolerating whitespace so both '],[' and '], ['
        # are recognized as separators.
        potential_rows = re.split(r'\]\s*,\s*\[', json_str.strip()[1:-1])
        valid_rows = []
        for row_str in potential_rows:
            try:
                # Reconstruct and parse each potential row
                clean_row_str = row_str.replace('[', '').replace(']', '').strip()
                if clean_row_str:
                    valid_rows.append(json.loads(f'[{clean_row_str}]'))
            except json.JSONDecodeError:
                continue # Skip malformed rows
        return valid_rows if valid_rows else None
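
# Illustrative example (assumed model output, truncated mid-row):
#   extract_json_from_text('Sure! [["Anna", 23], ["Ben", 27], ["Ca')
# The primary parse fails because the outer array never closes; the row-by-row
# fallback then recovers [["Anna", 23], ["Ben", 27]].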


def create_structured_prompt(commands: list[str], batch_size: int) -> str:
    """
    Creates a more structured and forceful prompt to ensure the model returns clean JSON.
    """
    cols_description = '\n'.join([f'- Column {i+1}: {cmd}' for i, cmd in enumerate(commands)])
    return f"""
Generate exactly {batch_size} rows of data.
Each inner array must have exactly {len(commands)} columns.

The columns are defined as follows:
{cols_description}

Your entire response must be ONLY the JSON array of arrays, with no additional text, explanations, or markdown.
Example of a valid response:
[["value1", "value2"], ["value3", "value4"]]
"""

# --- API Endpoints ---
@app.post("/generate", response_model=GenerationResponse)
async def generate_data(request: GenerationRequest):
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again in a moment.")

    start_time = time.time()
    try:
        # Create a more reliable prompt
        prompt = create_structured_prompt(request.llm_commands, request.batch_size)
        
        messages = [
            {"role": "system", "content": "You are a precise data generation machine. Your sole purpose is to return a valid JSON array of arrays. You will not deviate from this role."},
            {"role": "user", "content": prompt}
        ]
        
        # Apply the chat template
        text_input = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        model_inputs = tokenizer([text_input], return_tensors="pt").to(device)
        
        # Generate with no_grad context for better performance
        with torch.no_grad():
            # Size the token budget to the expected output: roughly 10 tokens per
            # cell plus a small fixed buffer, capped at 4096 in the call below.
            max_new_tokens = int(request.batch_size * len(request.llm_commands) * 10 + 50)
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=min(4096, max_new_tokens),
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id,
            )
        
        # Decode the output
        response_text = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
        
        # Extract and validate JSON data
        json_data = extract_json_from_text(response_text)
        
        final_data = []
        if json_data and isinstance(json_data, list):
            expected_cols = len(request.llm_commands)
            # Filter for valid rows and cap at the requested batch size
            final_data = [
                row for row in json_data 
                if isinstance(row, list) and len(row) == expected_cols
            ][:request.batch_size]
        else:
            print(f"Failed to parse JSON. Raw output: {response_text}")

        end_time = time.time()
        return {
            "data": final_data,
            "raw_output": response_text,
            "duration_s": round(end_time - start_time, 2)
        }

    except Exception as e:
        print(f"An error occurred during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# --- New Test Route ---
@app.get("/test", response_model=GenerationResponse, summary="Run a predefined test generation")
async def test_generation():
    """
    A simple test endpoint that generates 10 rows of sample data with fixed commands.
    This allows for easy performance testing and validation.
    """
    test_request = GenerationRequest(
        llm_commands=[
            "a common first name starting with the letter A",
            "an age as an integer between 20 and 30"
        ],
        batch_size=10
    )
    print("--- Running /test endpoint ---")
    return await generate_data(test_request)


# --- Health and Status Routes ---
@app.get("/", summary="Root status check")
def read_root():
    return {"status": "ok", "model_name": model_name, "device": device}

@app.get("/health", summary="Health check for the service")
def health_check():
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": device
    }
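
# --- Running the service (usage sketch) ---
# A minimal way to launch this app with uvicorn; the module name "main" below is
# an assumption and should match this file's actual name:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example request once the model has finished loading:
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"llm_commands": ["a color", "a fruit"], "batch_size": 3}'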