""" PlasmidGPT HuggingFace Space Deployment This Space loads the PlasmidGPT model and exposes it as a FastAPI service that can be called from your Render backend. """ import os import logging from typing import Dict, Any, Optional from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import torch from transformers import AutoTokenizer, AutoConfig, GenerationConfig from huggingface_hub import hf_hub_download import json import time # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize FastAPI app app = FastAPI( title="PlasmidGPT API", description="PlasmidGPT model API for DNA sequence generation", version="1.0.0" ) # Enable CORS for Render backend app.add_middleware( CORSMiddleware, allow_origins=["*"], # In production, restrict to your Render URL allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global model and tokenizer model = None tokenizer = None device = "cuda" if torch.cuda.is_available() else "cpu" def _manual_generate(model, input_ids, request): """ Manual generation for models without generate() method. Simple greedy/random sampling implementation. """ model.eval() generated = input_ids.clone() for _ in range(request.max_length - input_ids.shape[1]): with torch.no_grad(): outputs = model(generated) # Get logits from last position if isinstance(outputs, tuple): logits = outputs[0][:, -1, :] else: logits = outputs[:, -1, :] # Apply temperature if request.temperature != 1.0: logits = logits / request.temperature # Sample next token if request.do_sample: probs = torch.softmax(logits, dim=-1) next_token = torch.multinomial(probs, 1) else: next_token = torch.argmax(logits, dim=-1, keepdim=True) # Append to generated sequence generated = torch.cat([generated, next_token], dim=1) # Check for EOS token (if tokenizer has one) if hasattr(tokenizer, 'eos_token_id') and next_token.item() == tokenizer.eos_token_id: break return generated # Request/Response models class GenerationRequest(BaseModel): prompt: str = Field(..., description="DNA sequence prompt or seed") max_length: int = Field(100, ge=10, le=1000, description="Maximum sequence length") temperature: float = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature") num_return_sequences: int = Field(1, ge=1, le=3, description="Number of sequences to generate") do_sample: bool = Field(True, description="Whether to use sampling") repetition_penalty: float = Field(1.1, ge=1.0, le=2.0, description="Repetition penalty") class GenerationResponse(BaseModel): sequences: list[str] metadata: Dict[str, Any] generation_time: float class HealthResponse(BaseModel): model_config = {"protected_namespaces": ()} # Fix Pydantic warnings for model_* fields status: str model_loaded: bool device: str model_name: str @app.on_event("startup") async def load_model(): """ Load PlasmidGPT custom PyTorch model on startup. PlasmidGPT is NOT a standard transformers model - it's a custom PyTorch model that needs to be loaded with torch.load() and uses a custom tokenizer. 
""" global model, tokenizer logger.info("Loading PlasmidGPT custom model...") logger.info(f"Using device: {device}") try: model_name = "lingxusb/PlasmidGPT" # Download custom tokenizer file logger.info("Downloading custom tokenizer...") tokenizer_path = hf_hub_download( repo_id=model_name, filename="addgene_trained_dna_tokenizer.json", cache_dir="/tmp/hf_cache" ) # Load custom tokenizer logger.info("Loading custom tokenizer...") from tokenizers import Tokenizer tokenizer = Tokenizer.from_file(tokenizer_path) # Download model file logger.info("Downloading model file (this may take a few minutes)...") model_path = hf_hub_download( repo_id=model_name, filename="pretrained_model.pt", cache_dir="/tmp/hf_cache" ) # Load custom PyTorch model # Note: PyTorch 2.6+ requires weights_only=False for models with custom classes # This is safe since the model is from HuggingFace (trusted source) logger.info("Loading custom PyTorch model...") # Allowlist GPT2LMHeadModel class (if supported by PyTorch version) if hasattr(torch.serialization, 'add_safe_globals'): from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel torch.serialization.add_safe_globals([GPT2LMHeadModel]) # Load with weights_only=False (safe for HuggingFace models) model = torch.load(model_path, map_location=device, weights_only=False) model = model.to(device) model.eval() logger.info("✅ PlasmidGPT model loaded successfully!") logger.info(f"Model device: {next(model.parameters()).device}") except Exception as e: logger.error(f"Failed to load model: {str(e)}") logger.error(f"Error type: {type(e).__name__}") import traceback logger.error(traceback.format_exc()) raise @app.get("/", response_model=HealthResponse) async def root(): """Health check endpoint.""" return HealthResponse( status="healthy" if model is not None else "loading", model_loaded=model is not None, device=device, model_name="lingxusb/PlasmidGPT" ) @app.get("/health", response_model=HealthResponse) async def health(): """Health check endpoint.""" return HealthResponse( status="healthy" if model is not None else "loading", model_loaded=model is not None, device=device, model_name="lingxusb/PlasmidGPT" ) @app.post("/generate", response_model=GenerationResponse) async def generate_sequences(request: GenerationRequest): """ Generate DNA sequences using PlasmidGPT. Args: request: Generation parameters Returns: Generated sequences with metadata """ if model is None or tokenizer is None: raise HTTPException( status_code=503, detail="Model is still loading. Please wait and try again." 
@app.get("/", response_model=HealthResponse)
async def root():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy" if model is not None else "loading",
        model_loaded=model is not None,
        device=device,
        model_name="lingxusb/PlasmidGPT",
    )


@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy" if model is not None else "loading",
        model_loaded=model is not None,
        device=device,
        model_name="lingxusb/PlasmidGPT",
    )


@app.post("/generate", response_model=GenerationResponse)
async def generate_sequences(request: GenerationRequest):
    """
    Generate DNA sequences using PlasmidGPT.

    Args:
        request: Generation parameters

    Returns:
        Generated sequences with metadata
    """
    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading. Please wait and try again.",
        )

    try:
        start_time = time.time()

        # Tokenize the prompt with the custom tokenizer. encode() returns an
        # Encoding object whose .ids is a plain list, not a tensor.
        encoded = tokenizer.encode(request.prompt)
        input_ids = torch.tensor([encoded.ids], dtype=torch.long).to(device)

        # The unpickled model is expected to expose transformers' generate();
        # if it does not, or if generate() fails, fall back to the manual loop.
        with torch.no_grad():
            if hasattr(model, "generate"):
                try:
                    generation_config = (
                        GenerationConfig.from_model_config(model.config)
                        if hasattr(model, "config")
                        else None
                    )
                    outputs = model.generate(
                        input_ids,
                        max_length=request.max_length,
                        num_return_sequences=request.num_return_sequences,
                        temperature=request.temperature,
                        do_sample=request.do_sample,
                        repetition_penalty=request.repetition_penalty,
                        generation_config=generation_config,
                    )
                except Exception as e:
                    logger.warning(f"model.generate() failed: {e}, trying manual generation")
                    outputs = _manual_generate(model, input_ids, request)
            else:
                outputs = _manual_generate(model, input_ids, request)

        # Decode sequences. Each row of `outputs` is a 1-D tensor of token ids;
        # slice off the prompt tokens, then decode the remaining id list.
        sequences = []
        for output in outputs:
            generated = output[input_ids.shape[1]:].cpu().tolist()
            decoded = tokenizer.decode(generated, skip_special_tokens=True)
            sequences.append(decoded)

        generation_time = time.time() - start_time

        return GenerationResponse(
            sequences=sequences,
            metadata={
                "prompt": request.prompt,
                "prompt_length": len(request.prompt),
                "generated_lengths": [len(seq) for seq in sequences],
                "device": device,
                "model": "lingxusb/PlasmidGPT",
            },
            generation_time=generation_time,
        )

    except Exception as e:
        logger.error(f"Generation failed: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}",
        )


@app.post("/embed")
async def extract_embeddings(request: Dict[str, Any]):
    """
    Extract embeddings from sequences (placeholder - implement if needed).
    """
    raise HTTPException(
        status_code=501,
        detail="Embedding extraction not yet implemented in this Space deployment",
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
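
# ---------------------------------------------------------------------------
# Example client call from the Render backend (a minimal sketch; the Space
# hostname below is a placeholder -- substitute your actual Space URL):
#
#   import requests
#
#   SPACE_URL = "https://your-space.hf.space"  # hypothetical URL
#   resp = requests.post(
#       f"{SPACE_URL}/generate",
#       json={"prompt": "ATGGCCATTGTAATGGGCCGC", "max_length": 200},
#       timeout=120,
#   )
#   resp.raise_for_status()
#   print(resp.json()["sequences"])
# ---------------------------------------------------------------------------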