# Hugging Face Spaces page header (scrape residue): "Spaces: Running on Zero"
import gradio as gr
import torch
import os
import tempfile
import time
import polars as pl
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf, DictConfig

# --- InstaNovo Imports ---
# A failed import is reported but deliberately not fatal: the Gradio app can
# still start and surface the problem to the user instead of crashing at
# import time. Downstream code guards on the module-level globals being None.
try:
    from instanovo.transformer.model import InstaNovo
    from instanovo.utils import SpectrumDataFrame, ResidueSet, Metrics
    from instanovo.transformer.dataset import SpectrumDataset, collate_batch
    from instanovo.inference import (
        GreedyDecoder,
        KnapsackBeamSearchDecoder,
        Knapsack,
        ScoredSequence,
        Decoder,
    )
    from instanovo.constants import MASS_SCALE, MAX_MASS
    from torch.utils.data import DataLoader
except ImportError as e:
    print(f"Error importing InstaNovo components: {e}")
    print("Please ensure InstaNovo is installed correctly.")
    # Optionally, raise the error or exit if InstaNovo is critical
    # raise e
# --- Configuration ---
MODEL_ID = "instanovo-v1.1.0"  # Use the desired pretrained model ID
KNAPSACK_DIR = Path("./knapsack_cache")  # On-disk cache for the generated knapsack
DEFAULT_CONFIG_PATH = Path("./configs/inference/default.yaml")  # Assuming instanovo installs configs locally relative to execution
# Determine device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FP16 = DEVICE == "cuda"  # Enable FP16 only on CUDA
# --- Global Variables (Load Model and Knapsack Once) ---
# Populated by load_model_and_knapsack() at startup; each stays None until its
# load succeeds, and callers must check for None before use.
MODEL: InstaNovo | None = None
KNAPSACK: Knapsack | None = None
MODEL_CONFIG: DictConfig | None = None
RESIDUE_SET: ResidueSet | None = None
def load_model_and_knapsack() -> None:
    """Load the InstaNovo model and generate/load the knapsack.

    Populates the module-level MODEL, MODEL_CONFIG, RESIDUE_SET and KNAPSACK
    globals. Idempotent: returns early if the model is already loaded. If the
    knapsack cannot be loaded or generated, KNAPSACK is left as None and only
    greedy decoding will be available.

    Raises:
        gr.Error: if the pretrained model cannot be loaded, or if knapsack
            generation is needed but RESIDUE_SET is unavailable.
    """
    global MODEL, KNAPSACK, MODEL_CONFIG, RESIDUE_SET
    if MODEL is not None:
        print("Model already loaded.")
        return
    print(f"Loading InstaNovo model: {MODEL_ID} to {DEVICE}...")
    try:
        MODEL, MODEL_CONFIG = InstaNovo.from_pretrained(MODEL_ID)
        MODEL.to(DEVICE)
        MODEL.eval()
        RESIDUE_SET = MODEL.residue_set
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise gr.Error(f"Failed to load InstaNovo model: {MODEL_ID}. Error: {e}")
    # --- Knapsack Handling ---
    KNAPSACK_DIR.mkdir(parents=True, exist_ok=True)
    # A usable cache requires all three files written by Knapsack.save().
    knapsack_exists = (
        (KNAPSACK_DIR / "parameters.pkl").exists() and
        (KNAPSACK_DIR / "masses.npy").exists() and
        (KNAPSACK_DIR / "chart.npy").exists()
    )
    if knapsack_exists:
        print(f"Loading pre-generated knapsack from {KNAPSACK_DIR}...")
        try:
            KNAPSACK = Knapsack.from_file(str(KNAPSACK_DIR))
            print("Knapsack loaded successfully.")
        except Exception as e:
            print(f"Error loading knapsack: {e}. Will attempt to regenerate.")
            KNAPSACK = None  # Force regeneration
            knapsack_exists = False  # Ensure generation happens
    if not knapsack_exists:
        print("Knapsack not found or failed to load. Generating knapsack...")
        if RESIDUE_SET is None:
            raise gr.Error("Cannot generate knapsack because ResidueSet failed to load.")
        try:
            # Prepare residue masses for knapsack generation (handle negative/zero masses)
            residue_masses_knapsack = dict(RESIDUE_SET.residue_masses.copy())
            negative_residues = [k for k, v in residue_masses_knapsack.items() if v <= 0]
            if negative_residues:
                print(f"Warning: Non-positive masses found in residues: {negative_residues}. "
                      "Excluding from knapsack generation.")
                for res in negative_residues:
                    del residue_masses_knapsack[res]
            # Remove special tokens explicitly if they somehow got mass
            for special_token in RESIDUE_SET.special_tokens:
                if special_token in residue_masses_knapsack:
                    del residue_masses_knapsack[special_token]
            # Ensure residue indices used match those without special/negative masses
            valid_residue_indices = {
                res: idx for res, idx in RESIDUE_SET.residue_to_index.items()
                if res in residue_masses_knapsack
            }
            KNAPSACK = Knapsack.construct_knapsack(
                residue_masses=residue_masses_knapsack,
                residue_indices=valid_residue_indices,  # Use only valid indices
                max_mass=MAX_MASS,
                mass_scale=MASS_SCALE,
            )
            print(f"Knapsack generated. Saving to {KNAPSACK_DIR}...")
            KNAPSACK.save(str(KNAPSACK_DIR))  # Save for future runs
            print("Knapsack saved.")
        except Exception as e:
            # Best-effort: the app still works with greedy decoding only.
            print(f"Error generating or saving knapsack: {e}")
            gr.Warning("Failed to generate Knapsack. Knapsack Beam Search will not be available.")
            KNAPSACK = None  # Ensure it's None if generation failed

# Load the model and knapsack when the script starts (module import time).
load_model_and_knapsack()
def create_inference_config(
    input_path: str,
    output_path: str,
    decoding_method: str,
) -> DictConfig:
    """Assemble the OmegaConf configuration for a single inference run.

    Starts from the packaged default config when present (falling back to a
    built-in minimal config otherwise), then layers on run-specific
    overrides: input/output paths, device settings, and the decoder options
    implied by ``decoding_method`` (matched by substring).

    Raises:
        gr.Error: if Knapsack decoding is requested but no knapsack is loaded.
        ValueError: if ``decoding_method`` matches neither known method.
    """
    if DEFAULT_CONFIG_PATH.exists():
        base_cfg = OmegaConf.load(DEFAULT_CONFIG_PATH)
    else:
        print(f"Warning: Default config not found at {DEFAULT_CONFIG_PATH}. Using minimal config.")
        # Fall back to a self-contained minimal configuration.
        base_cfg = OmegaConf.create({
            "data_path": None,
            "output_path": None,
            "instanovo_model": MODEL_ID,
            "knapsack_path": str(KNAPSACK_DIR),
            "denovo": True,
            "refine": False,  # Not doing refinement here
            "num_beams": 1,
            "max_length": 40,
            "max_charge": 10,
            "isotope_error_range": [0, 1],
            "subset": 1.0,
            "use_knapsack": False,
            "save_beams": False,
            "batch_size": 64,  # Adjust as needed
            "device": DEVICE,
            "fp16": FP16,
            "log_interval": 500,  # Less relevant for Gradio app
            "use_basic_logging": True,
            "filter_precursor_ppm": 20,
            "filter_confidence": 1e-4,
            "filter_fdr_threshold": 0.05,
            # Default remappings from common modification notations to UNIMOD.
            "residue_remapping": {
                "M(ox)": "M[UNIMOD:35]", "M(+15.99)": "M[UNIMOD:35]",
                "S(p)": "S[UNIMOD:21]", "T(p)": "T[UNIMOD:21]", "Y(p)": "Y[UNIMOD:21]",
                "S(+79.97)": "S[UNIMOD:21]", "T(+79.97)": "T[UNIMOD:21]", "Y(+79.97)": "Y[UNIMOD:21]",
                "Q(+0.98)": "Q[UNIMOD:7]", "N(+0.98)": "N[UNIMOD:7]",
                "Q(+.98)": "Q[UNIMOD:7]", "N(+.98)": "N[UNIMOD:7]",
                "C(+57.02)": "C[UNIMOD:4]",
                "(+42.01)": "[UNIMOD:1]", "(+43.01)": "[UNIMOD:5]", "(-17.03)": "[UNIMOD:385]",
            },
            # Default input-column name mappings.
            "column_map": {
                "Modified sequence": "modified_sequence", "MS/MS m/z": "precursor_mz",
                "Mass": "precursor_mass", "Charge": "precursor_charge",
                "Mass values": "mz_array", "Mass spectrum": "mz_array",
                "Intensity": "intensity_array", "Raw intensity spectrum": "intensity_array",
                "Scan number": "scan_number"
            },
            "index_columns": [
                "scan_number", "precursor_mz", "precursor_charge",
            ],
        })
    # Run-specific overrides applied on top of the base configuration.
    run_overrides: dict = {
        "data_path": input_path,
        "output_path": output_path,
        "device": DEVICE,
        "fp16": FP16,
        "denovo": True,
        "refine": False,
    }
    if "Greedy" in decoding_method:
        run_overrides["num_beams"] = 1
        run_overrides["use_knapsack"] = False
    elif "Knapsack" in decoding_method:
        if KNAPSACK is None:
            raise gr.Error("Knapsack is not available. Cannot use Knapsack Beam Search.")
        run_overrides.update(
            num_beams=5,
            use_knapsack=True,
            knapsack_path=str(KNAPSACK_DIR),
        )
    else:
        raise ValueError(f"Unknown decoding method: {decoding_method}")
    return OmegaConf.merge(base_cfg, run_overrides)
def predict_peptides(input_file, decoding_method):
    """
    Main function to load data, run prediction, and return results.

    Args:
        input_file: Gradio file object for the uploaded spectrum file
            (.mgf/.mzml/.mzxml); its ``.name`` attribute holds the temp path.
        decoding_method: Radio-button label; matched by substring
            ("Greedy" / "Knapsack") inside create_inference_config.

    Returns:
        Tuple of (pandas DataFrame for on-screen display, path to the full
        results CSV for download).

    Raises:
        gr.Error: if the model is unavailable, no file was uploaded, the
            file cannot be parsed, or prediction fails for any reason.
    """
    if MODEL is None or RESIDUE_SET is None or MODEL_CONFIG is None:
        load_model_and_knapsack()  # Attempt to reload if None (e.g., after space restart)
        if MODEL is None:
            raise gr.Error("InstaNovo model is not loaded. Cannot perform prediction.")
    if input_file is None:
        raise gr.Error("Please upload a mass spectrometry file.")
    input_path = input_file.name  # Gradio provides the path in .name
    print(f"Processing file: {input_path}")
    print(f"Using decoding method: {decoding_method}")
    # Create a temporary file for the output CSV. delete=False: the path must
    # outlive this function so Gradio can serve it for download.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_out:
        output_csv_path = temp_out.name
    try:
        # 1. Create Config
        config = create_inference_config(input_path, output_csv_path, decoding_method)
        print("Inference Config:\n", OmegaConf.to_yaml(config))
        # 2. Load Data using SpectrumDataFrame
        print("Loading spectrum data...")
        try:
            sdf = SpectrumDataFrame.load(
                config.data_path,
                lazy=False,  # Load eagerly for Gradio simplicity
                is_annotated=False,  # De novo mode
                column_mapping=config.get("column_map", None),
                shuffle=False,
                verbose=True  # Print loading logs
            )
            # Apply charge filter like in CLI
            original_size = len(sdf)
            max_charge = config.get("max_charge", 10)
            sdf.filter_rows(
                lambda row: (row["precursor_charge"] <= max_charge) and (row["precursor_charge"] > 0)
            )
            if len(sdf) < original_size:
                print(f"Warning: Filtered {original_size - len(sdf)} spectra with charge > {max_charge} or <= 0.")
            if len(sdf) == 0:
                raise gr.Error("No valid spectra found in the uploaded file after filtering.")
            print(f"Data loaded: {len(sdf)} spectra.")
        except Exception as e:
            print(f"Error loading data: {e}")
            raise gr.Error(f"Failed to load or process the spectrum file. Error: {e}")
        # 3. Prepare Dataset and DataLoader
        ds = SpectrumDataset(
            sdf,
            RESIDUE_SET,
            MODEL_CONFIG.get("n_peaks", 200),
            return_str=True,  # Needed for greedy/beam search targets later (though not used here)
            annotated=False,
            pad_spectrum_max_length=config.get("compile_model", False) or config.get("use_flash_attention", False),
            bin_spectra=config.get("conv_peak_encoder", False),
        )
        dl = DataLoader(
            ds,
            batch_size=config.batch_size,
            num_workers=0,  # Required by SpectrumDataFrame
            shuffle=False,  # Required by SpectrumDataFrame
            collate_fn=collate_batch,
        )
        # 4. Select Decoder
        print("Initializing decoder...")
        decoder: Decoder
        if config.use_knapsack:
            if KNAPSACK is None:
                # This check should ideally be earlier, but double-check
                raise gr.Error("Knapsack is required for Knapsack Beam Search but is not available.")
            # KnapsackBeamSearchDecoder doesn't directly load from path in this version?
            # We load Knapsack globally, so just pass it.
            # If it needed path: decoder = KnapsackBeamSearchDecoder.from_file(model=MODEL, path=config.knapsack_path)
            decoder = KnapsackBeamSearchDecoder(model=MODEL, knapsack=KNAPSACK)
        elif config.num_beams > 1:
            # BeamSearchDecoder is available but not explicitly requested, use Greedy for num_beams=1
            print(f"Warning: num_beams={config.num_beams} > 1 but only Greedy and Knapsack Beam Search are implemented in this app. Defaulting to Greedy.")
            decoder = GreedyDecoder(model=MODEL, mass_scale=MASS_SCALE)
        else:
            decoder = GreedyDecoder(
                model=MODEL,
                mass_scale=MASS_SCALE,
                # Add suppression options if needed from config
                suppressed_residues=config.get("suppressed_residues", None),
                disable_terminal_residues_anywhere=config.get("disable_terminal_residues_anywhere", True),
            )
        print(f"Using decoder: {type(decoder).__name__}")
        # 5. Run Prediction Loop (Adapted from instanovo/transformer/predict.py)
        print("Starting prediction...")
        start_time = time.time()
        results_list: list[ScoredSequence | list] = []  # Store ScoredSequence or empty list
        for i, batch in enumerate(dl):
            spectra, precursors, spectra_mask, _, _ = batch  # Ignore peptides/masks for de novo
            spectra = spectra.to(DEVICE)
            precursors = precursors.to(DEVICE)
            spectra_mask = spectra_mask.to(DEVICE)
            with torch.no_grad(), torch.amp.autocast(DEVICE, dtype=torch.float16, enabled=FP16):
                # Beam search decoder might return list[list[ScoredSequence]] if return_beam=True
                # Greedy decoder returns list[ScoredSequence]
                # KnapsackBeamSearchDecoder returns list[ScoredSequence] or list[list[ScoredSequence]]
                batch_predictions = decoder.decode(
                    spectra=spectra,
                    precursors=precursors,
                    beam_size=config.num_beams,
                    max_length=config.max_length,
                    # Knapsack/Beam Search specific params if needed
                    mass_tolerance=config.get("filter_precursor_ppm", 20) * 1e-6,  # Convert ppm to relative
                    max_isotope=config.isotope_error_range[1] if config.isotope_error_range else 1,
                    return_beam=False  # Only get the top prediction for simplicity
                )
            results_list.extend(batch_predictions)  # Should be list[ScoredSequence] or list[list]
            print(f"Processed batch {i+1}/{len(dl)}")
        end_time = time.time()
        print(f"Prediction finished in {end_time - start_time:.2f} seconds.")
        # 6. Format Results
        print("Formatting results...")
        output_data = []
        # Use sdf index columns + prediction results
        index_cols = [col for col in config.index_columns if col in sdf.df.columns]
        base_df_pd = sdf.df.select(index_cols).to_pandas()  # Get base info
        metrics_calc = Metrics(RESIDUE_SET, config.isotope_error_range)
        # NOTE(review): predictions are matched to input rows positionally
        # (iloc[i]); this relies on shuffle=False above preserving sdf order.
        for i, res in enumerate(results_list):
            row_data = base_df_pd.iloc[i].to_dict()  # Get corresponding input data
            if isinstance(res, ScoredSequence) and res.sequence:
                sequence_str = "".join(res.sequence)
                row_data["prediction"] = sequence_str
                row_data["log_probability"] = f"{res.sequence_log_probability:.4f}"
                # Use metrics to calculate delta mass ppm for the top prediction
                try:
                    _, delta_mass_list = metrics_calc.matches_precursor(
                        res.sequence,
                        row_data["precursor_mz"],
                        row_data["precursor_charge"]
                    )
                    # Find the smallest absolute ppm error across isotopes
                    min_abs_ppm = min(abs(p) for p in delta_mass_list) if delta_mass_list else float('nan')
                    row_data["delta_mass_ppm"] = f"{min_abs_ppm:.2f}"
                except Exception as e:
                    print(f"Warning: Could not calculate delta mass for prediction {i}: {e}")
                    row_data["delta_mass_ppm"] = "N/A"
            else:
                # No prediction produced for this spectrum.
                row_data["prediction"] = ""
                row_data["log_probability"] = "N/A"
                row_data["delta_mass_ppm"] = "N/A"
            output_data.append(row_data)
        output_df = pl.DataFrame(output_data)
        # Ensure specific columns are present and ordered
        display_cols = ["scan_number", "precursor_mz", "precursor_charge", "prediction", "log_probability", "delta_mass_ppm"]
        final_display_cols = []
        for col in display_cols:
            if col in output_df.columns:
                final_display_cols.append(col)
            else:
                print(f"Warning: Expected display column '{col}' not found in results.")
        # Add any remaining index columns that weren't in display_cols
        for col in index_cols:
            if col not in final_display_cols and col in output_df.columns:
                final_display_cols.append(col)
        output_df_display = output_df.select(final_display_cols)
        # 7. Save full results to CSV
        print(f"Saving results to {output_csv_path}...")
        output_df.write_csv(output_csv_path)
        # Return DataFrame for display and path for download
        return output_df_display.to_pandas(), output_csv_path
    except Exception as e:
        print(f"An error occurred during prediction: {e}")
        # Clean up the temporary output file if it exists
        if os.path.exists(output_csv_path):
            os.remove(output_csv_path)
        # Re-raise as Gradio error
        raise gr.Error(f"Prediction failed: {e}")
# --- Gradio Interface ---
css = """
.gradio-container { font-family: sans-serif; }
.gr-button { color: white; border-color: black; background: black; }
footer { display: none !important; }
"""
with gr.Blocks(css=css, theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as demo:
    gr.Markdown(
        """
# π InstaNovo _De Novo_ Peptide Sequencing
Upload your mass spectrometry data file (.mgf, .mzml, or .mzxml) and get peptide sequence predictions using InstaNovo.
Choose between fast Greedy Search or more accurate but slower Knapsack Beam Search.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_file = gr.File(
                label="Upload Mass Spectrometry File (.mgf, .mzml, .mzxml)",
                file_types=[".mgf", ".mzml", ".mzxml"]
            )
            decoding_method = gr.Radio(
                ["Greedy Search (Fast)", "Knapsack Beam Search (More accurate, but slower)"],
                label="Decoding Method",
                value="Greedy Search (Fast)"  # Default to fast method
            )
            submit_btn = gr.Button("Predict Sequences", variant="primary")
        with gr.Column(scale=2):
            output_df = gr.DataFrame(label="Prediction Results", wrap=True)
            output_file = gr.File(label="Download Full Results (CSV)")
    submit_btn.click(
        predict_peptides,
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file]
    )
    # BUG FIX: the example's decoding-method string must exactly match one of
    # the gr.Radio choices above; it previously read
    # "Knapsack Beam Search (Accurate, 5 Beams)", which is not a valid choice
    # and would leave the radio component with an invalid value.
    gr.Examples(
        [["./sample_spectra.mgf", "Knapsack Beam Search (More accurate, but slower)"]],  # Requires test data fetched
        inputs=[input_file, decoding_method],
        outputs=[output_df, output_file],
        fn=predict_peptides,
        cache_examples=False,  # Re-run examples if needed
        label="Example Usage"
    )
    gr.Markdown(
        """
**Notes:**
* Predictions are based on the [InstaNovo](https://github.com/instadeepai/InstaNovo) model ({MODEL_ID}).
* Knapsack Beam Search uses pre-calculated mass constraints and yields better results but takes longer.
* 'delta_mass_ppm' shows the lowest absolute precursor mass error (in ppm) across potential isotopes (0-1 neutron).
* Ensure your input file format is correctly specified. Large files may take time to process.
""".format(MODEL_ID=MODEL_ID)
    )
# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for temporary public link if running locally
    # Set server_name="0.0.0.0" to allow access from network if needed
    # demo.launch(server_name="0.0.0.0", server_port=7860)
    # For Hugging Face Spaces, just demo.launch() is usually sufficient
    # (Spaces ignores share=True and warns; it is harmless there).
    demo.launch(share=True)  # For local testing with public URL