abhisheksan committed
Commit a22b331 · verified · 1 Parent(s): 2b79bbe

Update main.py

Files changed (1)
  1. main.py +437 -155
main.py CHANGED
@@ -1,39 +1,65 @@
 import os
-from typing import Optional, Dict, Any, List
-from fastapi import FastAPI, HTTPException, status, BackgroundTasks
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse
 import logging
 import sys
-from pydantic import BaseModel, Field, validator
+from datetime import datetime
+from typing import Optional, Dict, Any, List
+from functools import lru_cache
+
 import torch
-from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
-from contextlib import asynccontextmanager
 import asyncio
-from functools import lru_cache
 import numpy as np
-from datetime import datetime
 import re
+from fastapi import FastAPI, HTTPException, status, BackgroundTasks, Depends
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field, validator
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
+from contextlib import asynccontextmanager
+
+# Configuration
+class Config:
+    BASE_MODEL_DIR = "./models/"
+    MODEL_PATH = os.path.join(BASE_MODEL_DIR, "poeticagpt.pth")
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    BATCH_SIZE = 8  # Increased batch size for better throughput
+    CACHE_SIZE = 2048  # Increased cache size
+    MAX_QUEUE_SIZE = 16  # Maximum number of requests to queue
+    QUANTIZE_MODEL = True  # Enable quantization for improved performance
+    WARMUP_INPUTS = True  # Pre-warm the model with sample inputs
+    LOG_DIR = os.path.join(os.getcwd(), 'logs')
+    ENABLE_PROFILING = False  # Set to True to enable performance profiling
+    REQUEST_TIMEOUT = 30.0  # Timeout for request processing in seconds
+
+    MODEL_CONFIG = GPT2Config(
+        n_positions=400,
+        n_ctx=400,
+        n_embd=384,
+        n_layer=6,
+        n_head=6,
+        vocab_size=50257,
+        bos_token_id=50256,
+        eos_token_id=50256,
+        use_cache=True,
+    )

-# Constants
-BASE_MODEL_DIR = "./models/"
-MODEL_PATH = os.path.join(BASE_MODEL_DIR, "poeticagpt.pth")
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-BATCH_SIZE = 4
-CACHE_SIZE = 1024
+config = Config()

-MODEL_CONFIG = GPT2Config(
-    n_positions=400,
-    n_ctx=400,
-    n_embd=384,
-    n_layer=6,
-    n_head=6,
-    vocab_size=50257,
-    bos_token_id=50256,
-    eos_token_id=50256,
-    use_cache=True,
+# Configure logging
+os.makedirs(config.LOG_DIR, exist_ok=True)
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(sys.stdout),
+        logging.FileHandler(os.path.join(
+            config.LOG_DIR,
+            f'poetry_generation_{datetime.now().strftime("%Y%m%d")}.log'
+        ))
+    ]
 )
+logger = logging.getLogger(__name__)

+# Request models
 class GenerateRequest(BaseModel):
     prompt: str = Field(..., min_length=1, max_length=500)
     max_length: Optional[int] = Field(default=100, ge=10, le=500)
@@ -46,16 +72,20 @@ class GenerateRequest(BaseModel):

     @validator('prompt')
     def validate_prompt(cls, v):
+        # Normalize whitespace
         v = ' '.join(v.split())
         return v

+# Poem formatting module
 class PoemFormatter:
-    """Handles poem formatting and processing"""
+    """Efficient poem formatter with optimized text processing"""

     @staticmethod
     def format_free_verse(text: str) -> List[str]:
+        # More efficient regex splitting
         lines = re.split(r'[.!?]+|\n+', text)
         lines = [line.strip() for line in lines if line.strip()]
+
         formatted_lines = []
         for line in lines:
             if len(line) > 40:
@@ -63,24 +93,31 @@ class PoemFormatter:
                 formatted_lines.extend(part.strip() for part in parts if part.strip())
             else:
                 formatted_lines.append(line)
+
         return formatted_lines

     @staticmethod
     def format_haiku(text: str) -> List[str]:
+        # Precompile regex for performance
+        vowel_pattern = re.compile(r'[aeiou]+')
+
         words = text.split()
         lines = []
         current_line = []
         syllable_count = 0

+        syllable_targets = [5, 7, 5]  # Traditional haiku structure
+        current_target_idx = 0
+
         for word in words:
-            syllables = len(re.findall(r'[aeiou]+', word.lower()))
-            if syllable_count + syllables <= 5 and len(lines) == 0:
-                current_line.append(word)
-                syllable_count += syllables
-            elif syllable_count + syllables <= 7 and len(lines) == 1:
-                current_line.append(word)
-                syllable_count += syllables
-            elif syllable_count + syllables <= 5 and len(lines) == 2:
+            syllables = len(vowel_pattern.findall(word.lower())) or 1  # Ensure at least 1 syllable
+
+            if current_target_idx >= len(syllable_targets):
+                break
+
+            current_target = syllable_targets[current_target_idx]
+
+            if syllable_count + syllables <= current_target:
                 current_line.append(word)
                 syllable_count += syllables
             else:
@@ -88,13 +125,15 @@ class PoemFormatter:
                 lines.append(' '.join(current_line))
                 current_line = [word]
                 syllable_count = syllables
-
-            if len(lines) == 3:
-                break
-
-        if current_line and len(lines) < 3:
+                current_target_idx += 1
+
+        if current_line and len(lines) < len(syllable_targets):
             lines.append(' '.join(current_line))

+        # Ensure we have exactly 3 lines for a haiku
+        while len(lines) < 3:
+            lines.append("...")
+
         return lines[:3]

     @staticmethod
@@ -102,7 +141,7 @@ class PoemFormatter:
         words = text.split()
         lines = []
         current_line = []
-        target_line_length = 10
+        target_line_length = 10  # Approximate iambic pentameter

         for word in words:
             current_line.append(word)
@@ -110,33 +149,112 @@ class PoemFormatter:
                 lines.append(' '.join(current_line))
                 current_line = []

-            if len(lines) >= 14:
+            if len(lines) >= 14:  # Traditional sonnet has 14 lines
                 break

         if current_line and len(lines) < 14:
             lines.append(' '.join(current_line))

-        return lines[:14]
+        # Ensure we have 14 lines for a complete sonnet
+        while len(lines) < 14:
+            lines.append("...")
+
+        return lines
+
+    @staticmethod
+    def generate_title(poem_text: str) -> str:
+        words = poem_text.split()[:10]  # Use more words to find better title candidates
+        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'with', 'by'}
+        key_words = [word for word in words if word.lower() not in stop_words and len(word) > 2]
+
+        if key_words:
+            title = ' '.join(key_words[:3]).strip().capitalize()
+            return title if title else "Untitled"
+        return "Untitled"
+
+# Request queue for efficient processing
+class RequestQueue:
+    def __init__(self, max_size=config.MAX_QUEUE_SIZE):
+        self.queue = asyncio.Queue(maxsize=max_size)
+        self.semaphore = asyncio.Semaphore(max_size)
+
+    async def add_request(self, request_data):
+        async with self.semaphore:
+            return await asyncio.wait_for(
+                self._process_request(request_data),
+                timeout=config.REQUEST_TIMEOUT
+            )
+
+    async def _process_request(self, request_data):
+        future = asyncio.Future()
+        await self.queue.put((request_data, future))
+        return await future
+
+# Optimized Tokenization Service
+class TokenizationService:
+    def __init__(self):
+        self.tokenizer = None
+        self._lock = asyncio.Lock()
+
+    @lru_cache(maxsize=config.CACHE_SIZE)
+    def cached_tokenize(self, text):
+        return self.tokenizer.encode(text, return_tensors='pt')
+
+    async def initialize(self):
+        async with self._lock:
+            if self.tokenizer is None:
+                logger.info("Initializing tokenizer")
+                self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            return self.tokenizer
+
+    async def encode(self, text):
+        if not self.tokenizer:
+            await self.initialize()
+
+        # Use multithreading for tokenization if the text is large
+        if len(text) > 100:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                None,
+                lambda: self.cached_tokenize(text)
+            )
+        else:
+            return self.cached_tokenize(text)
+
+    def decode(self, tokens, skip_special_tokens=True):
+        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)

+# Model Manager with optimization techniques
 class ModelManager:
     def __init__(self):
         self.model = None
-        self.tokenizer = None
         self._lock = asyncio.Lock()
         self.request_count = 0
         self.last_cleanup = datetime.now()
+        self.model_ready = asyncio.Event()
+        self.tokenization_service = TokenizationService()
+        self.request_queue = RequestQueue()
         self.poem_formatter = PoemFormatter()
+        self.batch_processor_task = None

     async def initialize(self) -> bool:
         try:
-            self._setup_logging()
-
-            logger.info(f"Initializing model on device: {DEVICE}")
+            logger.info(f"Initializing model on device: {config.DEVICE}")

-            self.tokenizer = await self._load_tokenizer()
+            await self.tokenization_service.initialize()
             await self._load_and_optimize_model()

-            logger.info("Model and tokenizer loaded successfully")
+            # Start batch processing worker
+            self.batch_processor_task = asyncio.create_task(self._batch_processor_worker())
+
+            logger.info(f"Model and tokenizer loaded successfully on {config.DEVICE}")
+            self.model_ready.set()
+
+            # Warmup the model with dummy inputs
+            if config.WARMUP_INPUTS:
+                await self._warmup_model()
+
             return True

         except Exception as e:
@@ -144,92 +262,179 @@ class ModelManager:
             logger.exception("Detailed traceback:")
             return False

-    @staticmethod
-    def _setup_logging():
-        global logger
-        logger = logging.getLogger(__name__)
-        logger.setLevel(logging.INFO)
-
-        formatter = logging.Formatter(
-            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-        )
-
-        handlers = [logging.StreamHandler(sys.stdout)]
-
+    async def _batch_processor_worker(self):
+        """Worker that processes queued requests in batches"""
+        logger.info("Starting batch processor worker")
         try:
-            log_dir = os.path.join(os.getcwd(), 'logs')
-            os.makedirs(log_dir, exist_ok=True)
-            handlers.append(logging.FileHandler(
-                os.path.join(log_dir, f'poetry_generation_{datetime.now().strftime("%Y%m%d")}.log')
-            ))
+            while True:
+                # Process requests in batches when possible
+                if not self.request_queue.queue.empty():
+                    batch = []
+                    batch_futures = []
+
+                    # Get up to BATCH_SIZE requests from the queue
+                    batch_size = min(config.BATCH_SIZE, self.request_queue.queue.qsize())
+                    for _ in range(batch_size):
+                        if self.request_queue.queue.empty():
+                            break
+
+                        request_data, future = await self.request_queue.queue.get()
+                        batch.append(request_data)
+                        batch_futures.append(future)
+
+                    if batch:
+                        try:
+                            # Process the batch
+                            results = await self._process_batch(batch)
+
+                            # Set results to futures
+                            for i, future in enumerate(batch_futures):
+                                if not future.done():
+                                    future.set_result(results[i])
+                        except Exception as e:
+                            # Set exception to all futures in the batch
+                            for future in batch_futures:
+                                if not future.done():
+                                    future.set_exception(e)
+                        finally:
+                            # Mark tasks as done
+                            for _ in range(len(batch)):
+                                self.request_queue.queue.task_done()
+                else:
+                    # If queue is empty, sleep briefly before checking again
+                    await asyncio.sleep(0.01)
+
+        except asyncio.CancelledError:
+            logger.info("Batch processor worker cancelled")
         except Exception as e:
-            print(f"Warning: Could not create log file: {e}")
-
-        for handler in handlers:
-            handler.setFormatter(formatter)
-            logger.addHandler(handler)
+            logger.error(f"Error in batch processor worker: {str(e)}")
+            logger.exception("Detailed traceback")

-    @lru_cache(maxsize=CACHE_SIZE)
-    async def _load_tokenizer(self):
-        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        tokenizer.pad_token = tokenizer.eos_token
-        return tokenizer
+    async def _process_batch(self, batch_requests):
+        """Process a batch of requests efficiently"""
+        results = []
+
+        # Use with torch.no_grad() for all requests in the batch
+        with torch.no_grad():
+            for request in batch_requests:
+                try:
+                    # Prepare inputs
+                    inputs = await self._prepare_inputs(request.prompt)
+
+                    # Generate text
+                    outputs = await self._generate_optimized(inputs, request)
+
+                    # Process outputs
+                    result = await self._process_outputs(outputs, request)
+                    results.append(result)
+
+                except Exception as e:
+                    logger.error(f"Error processing request in batch: {str(e)}")
+                    results.append({"error": str(e)})
+
+        return results

     async def _load_and_optimize_model(self):
-        if not os.path.exists(MODEL_PATH):
-            raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
-
-        self.model = GPT2LMHeadModel(MODEL_CONFIG)
+        """Load and optimize the model with advanced techniques"""
+        async with self._lock:
+            if not os.path.exists(config.MODEL_PATH):
+                raise FileNotFoundError(f"Model file not found at {config.MODEL_PATH}")
+
+            # Create model with configuration
+            self.model = GPT2LMHeadModel(config.MODEL_CONFIG)
+
+            # Load state dict
+            state_dict = torch.load(config.MODEL_PATH, map_location=config.DEVICE)
+            self.model.load_state_dict(state_dict, strict=False)
+
+            # Move model to device
+            self.model.to(config.DEVICE)
+            self.model.eval()  # Set to evaluation mode
+
+            # Apply quantization if enabled and supported
+            if config.QUANTIZE_MODEL and config.DEVICE.type == 'cuda':
+                try:
+                    # Use dynamic quantization for better inference performance
+                    torch.quantization.quantize_dynamic(
+                        self.model, {torch.nn.Linear}, dtype=torch.qint8
+                    )
+                    logger.info("Model quantized successfully")
+                except Exception as e:
+                    logger.warning(f"Quantization failed, using full precision: {str(e)}")
+
+            # Apply other optimizations for CUDA devices
+            if config.DEVICE.type == 'cuda':
+                # Set optimization flags
+                torch.backends.cudnn.benchmark = True
+                torch.backends.cuda.matmul.allow_tf32 = True
+
+                # Convert model to TorchScript for faster inference
+                try:
+                    self.model = torch.jit.optimize_for_inference(
+                        torch.jit.script(self.model)
+                    )
+                    logger.info("Model optimized with TorchScript")
+                except Exception as e:
+                    logger.warning(f"TorchScript optimization failed: {str(e)}")
+
+    async def _warmup_model(self):
+        """Pre-warm the model with sample inputs to eliminate cold start issues"""
+        logger.info("Warming up model...")

-        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
-        self.model.load_state_dict(state_dict, strict=False)
+        # Create dummy inputs of different lengths
+        dummy_texts = [
+            "Write a poem about nature",
+            "Write a poem about love and loss in the modern world"
+        ]

-        self.model.to(DEVICE)
-        self.model.eval()
+        # Process dummy requests
+        dummy_requests = [
+            GenerateRequest(prompt=text, max_length=50, temperature=0.9)
+            for text in dummy_texts
+        ]

-        if DEVICE.type == 'cuda':
-            torch.backends.cudnn.benchmark = True
-            self.model = torch.jit.script(self.model)
-
-        dummy_input = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
-        with torch.no_grad():
-            self.model(dummy_input)
-
-    @torch.no_grad()
-    async def generate(self, request: GenerateRequest) -> Dict[str, Any]:
-        async with self._lock:
+        for req in dummy_requests:
             try:
-                self.request_count += 1
-                await self._check_cleanup()
-
-                inputs = await self._prepare_inputs(request.prompt)
-                outputs = await self._generate_optimized(inputs, request)
-
-                return await self._process_outputs(outputs, request)
-
+                with torch.no_grad():
+                    # Prepare inputs
+                    inputs = await self._prepare_inputs(req.prompt)
+
+                    # Run model inference
+                    _ = await self._generate_optimized(inputs, req)
+
             except Exception as e:
-                logger.error(f"Error generating text: {str(e)}")
-                raise HTTPException(
-                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                    detail=str(e)
-                )
+                logger.warning(f"Model warmup error: {str(e)}")
+
+        logger.info("Model warmup completed")

     async def _prepare_inputs(self, prompt: str):
+        """Prepare model inputs with optimized tokenization"""
         poetry_prompt = f"Write a poem about: {prompt}\n\nPoem:"
-        tokens = self.tokenizer.encode(poetry_prompt, return_tensors='pt')
-        return tokens.to(DEVICE)
+        tokens = await self.tokenization_service.encode(poetry_prompt)
+        return tokens.to(config.DEVICE)

     async def _generate_optimized(self, inputs, request: GenerateRequest):
-        attention_mask = torch.ones(inputs.shape, dtype=torch.long, device=DEVICE)
+        """Optimized text generation with style-specific parameters"""
+        attention_mask = torch.ones(inputs.shape, dtype=torch.long, device=config.DEVICE)

+        # Style-specific parameters
         style_params = {
-            "haiku": {"max_length": 50, "repetition_penalty": 1.3},
-            "sonnet": {"max_length": 200, "repetition_penalty": 1.2},
-            "free_verse": {"max_length": request.max_length, "repetition_penalty": request.repetition_penalty}
+            "haiku": {"max_length": 50, "repetition_penalty": 1.4, "no_repeat_ngram_size": 2},
+            "sonnet": {"max_length": 200, "repetition_penalty": 1.2, "no_repeat_ngram_size": 3},
+            "free_verse": {
+                "max_length": request.max_length,
+                "repetition_penalty": request.repetition_penalty,
+                "no_repeat_ngram_size": 3
+            }
         }

         params = style_params.get(request.style, style_params["free_verse"])

+        # Get bad word IDs for filtering
+        tokenizer = await self.tokenization_service.initialize()
+        bad_words = ['http', 'www', 'com', ':', '/', '#', '[', ']', '{', '}']
+        bad_words_ids = [[tokenizer.encode(word)[0]] for word in bad_words if len(tokenizer.encode(word)) > 0]
+
         return self.model.generate(
             inputs,
             attention_mask=attention_mask,
@@ -240,31 +445,35 @@ class ModelManager:
             top_p=request.top_p,
             repetition_penalty=params["repetition_penalty"],
             do_sample=True,
-            pad_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=tokenizer.eos_token_id,
             use_cache=True,
-            no_repeat_ngram_size=3,
+            no_repeat_ngram_size=params["no_repeat_ngram_size"],
             early_stopping=True,
-            bad_words_ids=[[self.tokenizer.encode(word)[0]] for word in
-                           ['http', 'www', 'com', ':', '/', '#']],
-            min_length=20,
+            bad_words_ids=bad_words_ids,
+            min_length=20 if request.style != "haiku" else 10,
         )

     async def _process_outputs(self, outputs, request: GenerateRequest):
-        raw_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        """Process and format the generated text into a poem"""
+        # Decode generated text
+        raw_text = self.tokenization_service.decode(outputs[0], skip_special_tokens=True)

+        # Extract poem from generated text
         prompt_pattern = f"Write a poem about: {request.prompt}\n\nPoem:"
         poem_text = raw_text.replace(prompt_pattern, '').strip()

+        # Format based on style
         if request.style == "haiku":
-            formatted_lines = PoemFormatter.format_haiku(poem_text)
+            formatted_lines = self.poem_formatter.format_haiku(poem_text)
         elif request.style == "sonnet":
-            formatted_lines = PoemFormatter.format_sonnet(poem_text)
+            formatted_lines = self.poem_formatter.format_sonnet(poem_text)
         else:
-            formatted_lines = PoemFormatter.format_free_verse(poem_text)
+            formatted_lines = self.poem_formatter.format_free_verse(poem_text)

+        # Generate response
         return {
             "poem": {
-                "title": self._generate_title(poem_text),
+                "title": self.poem_formatter.generate_title(poem_text),
                 "lines": formatted_lines,
                 "style": request.style
             },
@@ -277,47 +486,92 @@ class ModelManager:
                 "repetition_penalty": request.repetition_penalty
             },
             "metadata": {
-                "device": DEVICE.type,
-                "model_type": "GPT2",
+                "device": config.DEVICE.type,
+                "model_type": "GPT2-Optimized",
                 "timestamp": datetime.now().isoformat()
             }
         }

-    def _generate_title(self, poem_text: str) -> str:
-        words = poem_text.split()[:6]
-        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to'}
-        key_words = [word for word in words if word.lower() not in stop_words]
+    async def generate(self, request: GenerateRequest) -> Dict[str, Any]:
+        """Queue a request for generation and await result"""
+        try:
+            # Wait for model to be ready
+            await asyncio.wait_for(self.model_ready.wait(), timeout=60.0)
+
+            self.request_count += 1
+
+            # Add request to queue and get result
+            result = await self.request_queue.add_request(request)
+            return result
+
+        except asyncio.TimeoutError:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="Model is still initializing or overloaded"
+            )
+        except Exception as e:
+            logger.error(f"Error generating text: {str(e)}")
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=str(e)
+            )
+
+    async def cleanup(self):
+        """Perform memory cleanup operations"""
+        if config.DEVICE.type == 'cuda':
+            torch.cuda.empty_cache()
+
+        self.last_cleanup = datetime.now()
+        logger.info("Memory cleanup performed")
+
+    async def shutdown(self):
+        """Clean shutdown of the model manager"""
+        # Cancel batch processor worker
+        if self.batch_processor_task:
+            self.batch_processor_task.cancel()
+            try:
+                await self.batch_processor_task
+            except asyncio.CancelledError:
+                pass

-        if key_words:
-            title = ' '.join(key_words[:3]).capitalize()
-            return title
-        return "Untitled"
+        # Clear model from memory
+        if self.model is not None:
+            self.model = None
+
+        # Clear tokenizer from memory
+        if self.tokenization_service.tokenizer is not None:
+            self.tokenization_service.tokenizer = None
+
+        # Final memory cleanup
+        if config.DEVICE.type == 'cuda':
+            torch.cuda.empty_cache()

-    async def _check_cleanup(self):
-        if self.request_count % 100 == 0:
-            if DEVICE.type == 'cuda':
-                torch.cuda.empty_cache()
-            self.last_cleanup = datetime.now()
+# Create model manager instance
+model_manager = ModelManager()

+# FastAPI lifespan
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    if not await model_manager.initialize():
+    # Initialize on startup
+    initialized = await model_manager.initialize()
+    if not initialized:
         logger.error("Failed to initialize model manager")
+
     yield
-    if model_manager.model is not None:
-        del model_manager.model
-    if model_manager.tokenizer is not None:
-        del model_manager.tokenizer
-    if DEVICE.type == 'cuda':
-        torch.cuda.empty_cache()
-
+
+    # Clean up on shutdown
+    logger.info("Shutting down Poetry Generation API")
+    await model_manager.shutdown()
+
+# Create FastAPI app
 app = FastAPI(
     title="Poetry Generation API",
-    description="Optimized API for generating poetry using GPT-2",
-    version="2.0.0",
+    description="High-Performance API for generating poetry using GPT-2",
+    version="3.0.0",
     lifespan=lifespan
 )

+# Add CORS middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -326,23 +580,30 @@ app.add_middleware(
     allow_headers=["*"],
 )

-model_manager = ModelManager()
-
+# Health check endpoint
 @app.api_route("/health", methods=["GET", "HEAD"])
 async def health_check():
     return {
         "status": "healthy",
         "model_loaded": model_manager.model is not None,
-        "tokenizer_loaded": model_manager.tokenizer is not None,
-        "device": DEVICE.type,
+        "model_ready": model_manager.model_ready.is_set(),
+        "tokenizer_loaded": model_manager.tokenization_service.tokenizer is not None,
+        "device": config.DEVICE.type,
         "request_count": model_manager.request_count,
+        "queue_size": model_manager.request_queue.queue.qsize(),
         "last_cleanup": model_manager.last_cleanup.isoformat(),
         "system_info": {
             "cuda_available": torch.cuda.is_available(),
             "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
+            "cuda_memory": {
+                "allocated": f"{torch.cuda.memory_allocated() / (1024**2):.2f} MB",
+                "reserved": f"{torch.cuda.memory_reserved() / (1024**2):.2f} MB",
+                "max_allocated": f"{torch.cuda.max_memory_allocated() / (1024**2):.2f} MB"
+            } if torch.cuda.is_available() else {},
         }
     }

+# Poetry generation endpoint
 @app.post("/generate")
 async def generate_text(
     request: GenerateRequest,
@@ -351,16 +612,37 @@ async def generate_text(
     try:
         result = await model_manager.generate(request)

-        if model_manager.request_count % 100 == 0:
-            background_tasks.add_task(torch.cuda.empty_cache)
+        # Schedule cleanup every 50 requests
+        if model_manager.request_count % 50 == 0:
+            background_tasks.add_task(model_manager.cleanup)

         return JSONResponse(
             content=result,
             status_code=status.HTTP_200_OK
         )
+    except HTTPException as e:
+        # Re-raise HTTP exceptions
+        raise
     except Exception as e:
         logger.error(f"Error in generate_text: {str(e)}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=str(e)
-        )
+        )
+
+# Add profiling endpoint if profiling is enabled
+if config.ENABLE_PROFILING:
+    @app.get("/profiling")
+    async def get_profiling():
+        if config.DEVICE.type == 'cuda':
+            return {
+                "memory": {
+                    "allocated": f"{torch.cuda.memory_allocated() / (1024**2):.2f} MB",
+                    "reserved": f"{torch.cuda.memory_reserved() / (1024**2):.2f} MB",
+                    "max_allocated": f"{torch.cuda.max_memory_allocated() / (1024**2):.2f} MB"
+                },
+                "request_count": model_manager.request_count,
+                "queue_size": model_manager.request_queue.queue.qsize(),
+            }
+        else:
+            return {"device": "cpu", "profiling": "not available"}
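For reference, a minimal client sketch against the two endpoints this commit defines. The host, port, and the exact set of request fields beyond prompt, style, and max_length are assumptions (the full GenerateRequest model is only partially visible in the diff), so treat this as an illustration rather than the project's documented usage.

```python
# Illustrative client sketch only; not part of the commit.
# Assumes the app is served locally, e.g.: uvicorn main:app --port 8000
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port

# Check that the model finished loading before sending work
health = requests.get(f"{BASE_URL}/health", timeout=10).json()
print("model_ready:", health.get("model_ready"))

# Request a haiku; field names mirror GenerateRequest in the diff above
payload = {
    "prompt": "autumn rain on the harbor",
    "style": "haiku",
    "max_length": 50,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=60)
resp.raise_for_status()
poem = resp.json()["poem"]
print(poem["title"])
print("\n".join(poem["lines"]))
```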