from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import re
import time
from contextlib import asynccontextmanager
# --- Performance Optimizations & Model Loading ---
# 1. Device Selection: Use CUDA GPU if available for a massive speed boost.
device = "cuda" if torch.cuda.is_available() else "cpu"
# 2. Data Type: Use float16 on GPU for faster computation and less memory usage.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"--- System Info ---")
print(f"Using device: {device}")
print(f"Using dtype: {torch_dtype}")
print("--------------------")
# --- App State and Model Placeholders ---
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = None
model = None
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Handles startup and shutdown events.
Loads the ML model and tokenizer on startup.
"""
global tokenizer, model
print("Loading model and tokenizer...")
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Set pad token if it's not already set
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
try:
# 3. Attention Mechanism: Use Flash Attention 2 for a ~2x speedup on compatible GPUs.
print("Attempting to load model with Flash Attention 2...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2"
).to(device)
print("Successfully loaded model with Flash Attention 2.")
except (ImportError, RuntimeError) as e:
print(f"Flash Attention 2 not available ({e}), falling back to default attention.")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
).to(device)
# 4. Model Compilation (PyTorch 2.0+): JIT-compiles the model for faster execution.
print("Compiling model with torch.compile()...")
try:
model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
print("Model compiled successfully.")
except Exception as e:
print(f"torch.compile() failed: {e}. Running with uncompiled model.")
end_time = time.time()
print(f"Model loading and compilation finished in {end_time - start_time:.2f} seconds.")
yield
# Clean up resources on shutdown (optional)
print("Cleaning up and shutting down.")
model = None
tokenizer = None
# --- FastAPI App Initialization ---
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=['*'],
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
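# Wide-open CORS is convenient for a demo Space; for production, restrict
# allow_origins to the specific frontend hosts that need access.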
# --- API Request and Response Models ---
class GenerationRequest(BaseModel):
llm_commands: list[str]
batch_size: int = 50
class GenerationResponse(BaseModel):
data: list
raw_output: str # Added for debugging
duration_s: float # Added for performance tracking
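# Illustrative round trip for the models above (values are made up):
#   POST body: {"llm_commands": ["a first name", "an age"], "batch_size": 2}
#   Response:  {"data": [["Alice", 27], ["Ben", 24]],
#               "raw_output": "[[\"Alice\", 27], [\"Ben\", 24]]",
#               "duration_s": 1.42}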
# --- Helper Functions ---
def extract_json_from_text(text: str):
"""
Extracts a JSON array from the model's raw text output.
This version is more robust and handles incomplete JSON at the end.
"""
# Find the first '[' and the last ']' to bound the JSON content
start_bracket = text.find('[')
end_bracket = text.rfind(']')
if start_bracket == -1 or end_bracket == -1:
return None # No JSON array found
json_str = text[start_bracket : end_bracket + 1]
try:
# Attempt to parse the primary JSON string
return json.loads(json_str)
except json.JSONDecodeError:
        # Fallback for malformed JSON (e.g. generation truncated mid-array):
        # recover as many complete rows as possible, row by row.
        print("Warning: Initial JSON parsing failed. Attempting to recover partial data.")
        # Split on '],[' while tolerating whitespace/newlines between rows.
        potential_rows = re.split(r'\]\s*,\s*\[', json_str.strip()[1:-1])
valid_rows = []
for row_str in potential_rows:
try:
# Reconstruct and parse each potential row
clean_row_str = row_str.replace('[', '').replace(']', '').strip()
if clean_row_str:
valid_rows.append(json.loads(f'[{clean_row_str}]'))
except json.JSONDecodeError:
continue # Skip malformed rows
return valid_rows if valid_rows else None
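# Minimal sketch of the recovery path above, assuming the model's output was
# cut off mid-row (the trailing '["Ca' never closes):
#   extract_json_from_text('[["Alice", 27], ["Ben", 24], ["Ca')
# rfind(']') bounds the string at the last complete row, json.loads() still
# fails on the unclosed outer array, and the row-by-row fallback recovers
# [["Alice", 27], ["Ben", 24]].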
def create_structured_prompt(commands: list[str], batch_size: int) -> str:
"""
Creates a more structured and forceful prompt to ensure the model returns clean JSON.
"""
cols_description = '\n'.join([f'- Column {i+1}: {cmd}' for i, cmd in enumerate(commands)])
return f"""
Generate exactly {batch_size} rows of data.
Each inner array must have exactly {len(commands)} columns.
The columns are defined as follows:
{cols_description}
Your entire response must be ONLY the JSON array of arrays, with no additional text, explanations, or markdown.
Example of a valid response:
[["value1", "value2"], ["value3", "value4"]]
"""
# --- API Endpoints ---
@app.post("/generate", response_model=GenerationResponse)
async def generate_data(request: GenerationRequest):
    if model is None or tokenizer is None:
raise HTTPException(status_code=503, detail="Model is not ready. Please try again in a moment.")
start_time = time.time()
try:
# Create a more reliable prompt
prompt = create_structured_prompt(request.llm_commands, request.batch_size)
messages = [
{"role": "system", "content": "You are a precise data generation machine. Your sole purpose is to return a valid JSON array of arrays. You will not deviate from this role."},
{"role": "user", "content": prompt}
]
# Apply the chat template
text_input = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text_input], return_tensors="pt").to(device)
# Generate with no_grad context for better performance
with torch.no_grad():
            # Budget roughly 10 new tokens per cell plus a small constant
            # buffer; the value is capped at 4096 in the generate() call below.
            max_new_tokens = int(request.batch_size * len(request.llm_commands) * 10 + 50)
generated_ids = model.generate(
**model_inputs,
max_new_tokens=min(4096, max_new_tokens),
do_sample=True,
temperature=0.7,
top_p=0.95,
pad_token_id=tokenizer.pad_token_id,
)
# Decode the output
response_text = tokenizer.batch_decode(generated_ids[:, model_inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]
# Extract and validate JSON data
json_data = extract_json_from_text(response_text)
final_data = []
if json_data and isinstance(json_data, list):
expected_cols = len(request.llm_commands)
# Filter for valid rows and cap at the requested batch size
final_data = [
row for row in json_data
if isinstance(row, list) and len(row) == expected_cols
][:request.batch_size]
else:
print(f"Failed to parse JSON. Raw output: {response_text}")
end_time = time.time()
return {
"data": final_data,
"raw_output": response_text,
"duration_s": round(end_time - start_time, 2)
}
except Exception as e:
print(f"An error occurred during generation: {e}")
raise HTTPException(status_code=500, detail=str(e))
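# Example invocation, assuming the server listens on localhost:8000 (adjust
# host/port to your deployment):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"llm_commands": ["a city name", "its country"], "batch_size": 5}'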
# --- New Test Route ---
@app.get("/test", response_model=GenerationResponse, summary="Run a predefined test generation")
async def test_generation():
"""
A simple test endpoint that generates 10 rows of sample data with fixed commands.
This allows for easy performance testing and validation.
"""
test_request = GenerationRequest(
llm_commands=[
"a common first name starting with the letter A",
"an age as an integer between 20 and 30"
],
batch_size=10
)
print("--- Running /test endpoint ---")
return await generate_data(test_request)
# --- Health and Status Routes ---
@app.get("/", summary="Root status check")
def read_root():
return {"status": "ok", "model_name": model_name, "device": device}
@app.get("/health", summary="Health check for the service")
def health_check():
return {
"status": "healthy",
"model_loaded": model is not None,
"tokenizer_loaded": tokenizer is not None,
"device": device
} |
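# Local entrypoint sketch. uvicorn is assumed here (the usual ASGI server for
# FastAPI, though it is not imported elsewhere in this file); port 7860
# matches the Hugging Face Spaces convention, any free port works locally.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)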