kingkulk committed on
Commit
e24512b
·
verified ·
1 Parent(s): 438f2de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -46
app.py CHANGED
@@ -160,56 +160,27 @@ async def load_model():
160
  model = model.to(device)
161
  model.eval()
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  logger.info("✅ PlasmidGPT model loaded successfully!")
164
 
165
  except Exception as e:
166
- logger.error(f"Failed to load model: {str(e)}")
167
- logger.error(f"Error type: {type(e).__name__}")
168
- import traceback
169
- logger.error(traceback.format_exc())
170
- raise
171
- @app.get("/", response_model=HealthResponse)
172
- async def root():
173
- """Health check endpoint."""
174
- return HealthResponse(
175
- status="healthy" if model is not None else "loading",
176
- model_loaded=model is not None,
177
- device=device,
178
- model_name="lingxusb/PlasmidGPT"
179
- )
180
- @app.get("/health", response_model=HealthResponse)
181
- async def health():
182
- """Health check endpoint."""
183
- return HealthResponse(
184
- status="healthy" if model is not None else "loading",
185
- model_loaded=model is not None,
186
- device=device,
187
- model_name="lingxusb/PlasmidGPT"
188
- )
189
- @app.post("/generate", response_model=GenerationResponse)
190
- async def generate_sequences(request: GenerationRequest):
191
- """
192
- Generate DNA sequences using PlasmidGPT.
193
-
194
- Args:
195
- request: Generation parameters
196
-
197
- Returns:
198
- Generated sequences with metadata
199
- """
200
- if model is None or tokenizer is None:
201
- raise HTTPException(
202
- status_code=503,
203
- detail="Model is still loading. Please wait and try again."
204
- )
205
-
206
- try:
207
- start_time = time.time()
208
-
209
- # Tokenize input using custom tokenizer
210
- # Custom tokenizer uses encode() method (returns list, not tensor)
211
  encoded = tokenizer.encode(request.prompt)
212
  input_ids = torch.tensor([encoded.ids], dtype=torch.long).to(device)
 
213
 
214
  # Generate sequences using custom model
215
  # PlasmidGPT model has custom generate() method
@@ -232,7 +203,9 @@ async def generate_sequences(request: GenerationRequest):
232
  "max_length": request.max_length,
233
  "num_return_sequences": request.num_return_sequences,
234
  "temperature": request.temperature,
235
- "do_sample": request.do_sample
 
 
236
  }
237
  if generation_config:
238
  gen_kwargs["generation_config"] = generation_config
 
160
  model = model.to(device)
161
  model.eval()
162
 
163
+ # Patch model config for compatibility with newer transformers
164
+ if hasattr(model, 'config'):
165
+ # Ensure _output_attentions exists (fixes AttributeError)
166
+ if not hasattr(model.config, '_output_attentions'):
167
+ setattr(model.config, '_output_attentions', False)
168
+
169
+ # Ensure output_attentions property uses the attribute
170
+ if not hasattr(model.config, 'output_attentions'):
171
+ model.config.output_attentions = False
172
+
173
+ # Ensure other common missing attributes
174
+ if not hasattr(model.config, 'return_dict'):
175
+ model.config.return_dict = True
176
+
177
  logger.info("✅ PlasmidGPT model loaded successfully!")
178
 
179
  except Exception as e:
180
+ # ... (skip to generate_sequences)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  encoded = tokenizer.encode(request.prompt)
182
  input_ids = torch.tensor([encoded.ids], dtype=torch.long).to(device)
183
+ attention_mask = torch.ones_like(input_ids).to(device)
184
 
185
  # Generate sequences using custom model
186
  # PlasmidGPT model has custom generate() method
 
203
  "max_length": request.max_length,
204
  "num_return_sequences": request.num_return_sequences,
205
  "temperature": request.temperature,
206
+ "do_sample": request.do_sample,
207
+ "attention_mask": attention_mask,
208
+ "pad_token_id": tokenizer.eos_token_id if hasattr(tokenizer, 'eos_token_id') else 50256
209
  }
210
  if generation_config:
211
  gen_kwargs["generation_config"] = generation_config