| """ | |
| PlasmidGPT HuggingFace Space Deployment | |
| This Space loads the PlasmidGPT model and exposes it as a FastAPI service | |
| that can be called from your Render backend. | |
| """ | |
| import os | |
| import logging | |
| from typing import Dict, Any, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| import torch | |
| from transformers import AutoTokenizer, AutoConfig, GenerationConfig | |
| from huggingface_hub import hf_hub_download | |
| import json | |
| import time | |

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="PlasmidGPT API",
    description="PlasmidGPT model API for DNA sequence generation",
    version="1.0.0",
)

# Enable CORS so the Render backend can call this Space
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict this to your Render URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
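# A locked-down configuration would replace the wildcard above with the real
# backend origin, e.g. (hypothetical URL):
#
#     allow_origins=["https://your-backend.onrender.com"]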

# Global model and tokenizer, populated at startup
model = None
tokenizer = None
device = "cuda" if torch.cuda.is_available() else "cpu"

def _manual_generate(model, input_ids, request):
    """
    Manual fallback generation for models without a usable generate() method.
    Simple greedy/multinomial sampling; note that repetition_penalty and
    num_return_sequences are not applied in this fallback.
    """
    model.eval()
    generated = input_ids.clone()
    for _ in range(request.max_length - input_ids.shape[1]):
        with torch.no_grad():
            outputs = model(generated)
        # Get the logits for the last position; handle ModelOutput, tuple,
        # and raw-tensor return types
        if hasattr(outputs, "logits"):
            logits = outputs.logits[:, -1, :]
        elif isinstance(outputs, tuple):
            logits = outputs[0][:, -1, :]
        else:
            logits = outputs[:, -1, :]
        # Apply temperature
        if request.temperature != 1.0:
            logits = logits / request.temperature
        # Sample the next token
        if request.do_sample:
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        # Append to the generated sequence
        generated = torch.cat([generated, next_token], dim=1)
        # Stop at EOS if the tokenizer defines one (the raw tokenizers.Tokenizer
        # used here has no eos_token_id attribute, so this is usually a no-op)
        eos_id = getattr(tokenizer, "eos_token_id", None)
        if eos_id is not None and next_token.item() == eos_id:
            break
    return generated

# Request/response models
class GenerationRequest(BaseModel):
    prompt: str = Field(..., description="DNA sequence prompt or seed")
    max_length: int = Field(100, ge=10, le=1000, description="Maximum sequence length in tokens")
    # gt=0.0 (not ge) so a zero temperature cannot cause division by zero
    temperature: float = Field(0.7, gt=0.0, le=2.0, description="Sampling temperature")
    num_return_sequences: int = Field(1, ge=1, le=3, description="Number of sequences to generate")
    do_sample: bool = Field(True, description="Whether to use sampling")
    repetition_penalty: float = Field(1.1, ge=1.0, le=2.0, description="Repetition penalty")
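
# Example /generate request body accepted by the model above (illustrative
# values only):
#
#     {
#         "prompt": "ATGGCTAGC",
#         "max_length": 200,
#         "temperature": 0.7,
#         "num_return_sequences": 1,
#         "do_sample": true,
#         "repetition_penalty": 1.1
#     }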

class GenerationResponse(BaseModel):
    sequences: list[str]
    metadata: Dict[str, Any]
    generation_time: float


class HealthResponse(BaseModel):
    model_config = {"protected_namespaces": ()}  # Silence Pydantic warnings for model_* field names

    status: str
    model_loaded: bool
    device: str
    model_name: str

@app.on_event("startup")
async def load_model():
    """
    Load the PlasmidGPT custom PyTorch model on startup.

    PlasmidGPT is NOT a standard transformers model: it is a pickled PyTorch
    model that must be loaded with torch.load() and paired with the custom
    tokenizer shipped in the same repo.
    """
    global model, tokenizer
    logger.info("Loading PlasmidGPT custom model...")
    logger.info(f"Using device: {device}")
    try:
        model_name = "lingxusb/PlasmidGPT"

        # Download the custom tokenizer file
        logger.info("Downloading custom tokenizer...")
        tokenizer_path = hf_hub_download(
            repo_id=model_name,
            filename="addgene_trained_dna_tokenizer.json",
            cache_dir="/tmp/hf_cache",
        )

        # Load the custom tokenizer
        logger.info("Loading custom tokenizer...")
        from tokenizers import Tokenizer
        tokenizer = Tokenizer.from_file(tokenizer_path)

        # Download the model file
        logger.info("Downloading model file (this may take a few minutes)...")
        model_path = hf_hub_download(
            repo_id=model_name,
            filename="pretrained_model.pt",
            cache_dir="/tmp/hf_cache",
        )

        # Load the custom PyTorch model.
        # PyTorch 2.6+ defaults to weights_only=True, which rejects pickled
        # models containing custom classes; weights_only=False is acceptable
        # here only because the checkpoint comes from a trusted repo.
        logger.info("Loading custom PyTorch model...")

        # Allowlist the GPT2LMHeadModel class (if supported by this PyTorch version)
        if hasattr(torch.serialization, "add_safe_globals"):
            from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
            torch.serialization.add_safe_globals([GPT2LMHeadModel])

        model = torch.load(model_path, map_location=device, weights_only=False)
        model = model.to(device)
        model.eval()

        logger.info("✅ PlasmidGPT model loaded successfully!")
        logger.info(f"Model device: {next(model.parameters()).device}")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        logger.error(f"Error type: {type(e).__name__}")
        import traceback
        logger.error(traceback.format_exc())
        raise

@app.get("/", response_model=HealthResponse)
async def root():
    """Health check endpoint (root)."""
    return HealthResponse(
        status="healthy" if model is not None else "loading",
        model_loaded=model is not None,
        device=device,
        model_name="lingxusb/PlasmidGPT",
    )


@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy" if model is not None else "loading",
        model_loaded=model is not None,
        device=device,
        model_name="lingxusb/PlasmidGPT",
    )
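
# Quick smoke test once the Space is running (hypothetical Space URL):
#
#     curl https://your-username-plasmidgpt.hf.space/health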

@app.post("/generate", response_model=GenerationResponse)
async def generate_sequences(request: GenerationRequest):
    """
    Generate DNA sequences using PlasmidGPT.

    Args:
        request: Generation parameters

    Returns:
        Generated sequences with metadata
    """
    if model is None or tokenizer is None:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading. Please wait and try again.",
        )
    try:
        start_time = time.time()

        # Tokenize input with the custom tokenizer; encode() returns an
        # Encoding object whose .ids is a plain list, not a tensor
        encoded = tokenizer.encode(request.prompt)
        input_ids = torch.tensor([encoded.ids], dtype=torch.long).to(device)

        # Generate sequences
        with torch.no_grad():
            if hasattr(model, "generate"):
                # Prefer the model's own generate() method
                try:
                    generation_config = (
                        GenerationConfig.from_model_config(model.config)
                        if hasattr(model, "config")
                        else None
                    )
                    outputs = model.generate(
                        input_ids,
                        max_length=request.max_length,
                        num_return_sequences=request.num_return_sequences,
                        temperature=request.temperature,
                        do_sample=request.do_sample,
                        repetition_penalty=request.repetition_penalty,
                        generation_config=generation_config,
                    )
                except Exception as e:
                    logger.warning(f"model.generate() failed: {e}; falling back to manual generation")
                    outputs = _manual_generate(model, input_ids, request)
            else:
                # Manual generation if the model has no generate() method
                outputs = _manual_generate(model, input_ids, request)

        # Decode each sequence, excluding the prompt tokens
        sequences = []
        for output in outputs:
            generated_ids = output[input_ids.shape[1]:].cpu().tolist()
            # The custom tokenizer's decode() expects a list of token IDs
            decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
            sequences.append(decoded)

        generation_time = time.time() - start_time
        return GenerationResponse(
            sequences=sequences,
            metadata={
                "prompt": request.prompt,
                "prompt_length": len(request.prompt),
                "generated_lengths": [len(seq) for seq in sequences],
                "device": device,
                "model": "lingxusb/PlasmidGPT",
            },
            generation_time=generation_time,
        )
    except Exception as e:
        logger.error(f"Generation failed: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}",
        )

@app.post("/embeddings")
async def extract_embeddings(request: Dict[str, Any]):
    """
    Extract embeddings from sequences (placeholder; implement if needed).
    """
    raise HTTPException(
        status_code=501,
        detail="Embedding extraction not yet implemented in Space deployment",
    )

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
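
# Example client call from the Render backend: a minimal sketch using the
# requests library. The Space URL is a placeholder; substitute your actual
# *.hf.space hostname.
#
#     import requests
#
#     resp = requests.post(
#         "https://your-username-plasmidgpt.hf.space/generate",
#         json={"prompt": "ATGGCTAGC", "max_length": 200},
#         timeout=120,
#     )
#     resp.raise_for_status()
#     print(resp.json()["sequences"])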