# gradio_app.py
# Gradio rewrite of the original Streamlit app for TFBS prediction
import logging
import os

import gradio as gr
import numpy as np
from utils import dnaseq_features
from keras.models import load_model

logger = logging.getLogger(__name__)

# Load model once at startup (make sure best_model.h5 is present in the repo)
model = load_model('best_model.h5')

# Markdown labels used by _format_prediction for both output layouts.
_FOUND_LABEL = "**TFBS found ✅**"
_NOT_FOUND_LABEL = "**TFBS not found ❌**"


def _format_prediction(item, pred) -> str:
    """Render one (ds_val entry, model output row) pair as a Markdown bullet.

    Two output layouts are recognised:

    * two or more values (e.g. ``[p0, p1]``): labelled "found" when index 1
      is the argmax, with the value at index 1 shown as the probability;
    * a single value (a scalar, or a 1-element row such as one row of an
      ``(N, 1)`` output): labelled "found" when it exceeds 0.5.

    Anything else (non-numeric, empty, or genuinely multi-dimensional)
    falls back to printing the raw prediction.
    """
    raw_line = f"- `{item}` — prediction: {pred}"
    try:
        # squeeze drops singleton axes so (1,), (1, 2) etc. become 1-D (or 0-D);
        # atleast_1d turns a 0-D result back into a 1-element vector.
        values = np.atleast_1d(np.squeeze(np.asarray(pred, dtype=float)))
    except (TypeError, ValueError):
        return raw_line
    if values.ndim != 1 or values.size == 0:
        return raw_line

    if values.size >= 2:
        prob = float(values[1])
        label = _FOUND_LABEL if int(np.argmax(values)) == 1 else _NOT_FOUND_LABEL
        return f"- `{item}` — {label} (probability: {prob:.4f})"

    # Index explicitly: float() on a 1-element ndarray is deprecated in NumPy >= 1.25.
    score = float(values[0])
    label = _FOUND_LABEL if score > 0.5 else _NOT_FOUND_LABEL
    return f"- `{item}` — {label} (score: {score:.4f})"


def predict(dna_seq: str) -> str:
    """Take a DNA sequence string, compute features, run the model and return a Markdown summary.

    Args:
        dna_seq: Raw text from the input box. Surrounding whitespace is
            stripped and ``None`` is treated as empty.

    Returns:
        A Markdown string with one bullet per (ds_val entry, prediction row)
        pair. If the input is empty, or feature extraction or prediction
        fails, a human-readable message is returned instead. Failures are
        returned rather than raised, so the UI always shows something; they
        are also logged with a traceback.
    """
    dna_seq = (dna_seq or "").strip()
    if not dna_seq:
        return "**Please provide an input DNA sequence.**"

    # compute features using user's util function
    # NOTE(review): assumed to return (model-ready features, window indices,
    # per-window values to display); ds_index is currently unused.
    try:
        dna_ohe_feat, ds_index, ds_val = dnaseq_features(seq=dna_seq)
    except Exception as e:
        logger.exception("Feature computation failed")
        return f"Error while computing features: {e}"

    # predict
    try:
        predicted = model.predict(dna_ohe_feat)
    except Exception as e:
        logger.exception("Model prediction failed")
        return f"Error during model prediction: {e}"

    # zip() stops at the shorter input. NOTE(review): confirm dnaseq_features
    # always yields exactly one ds_val entry per model output row.
    lines = [_format_prediction(i, j) for i, j in zip(ds_val, predicted)]

    if not lines:
        return "No predictions were produced. Check the input or the model."
    return "\n".join(lines)


# Gradio UI
title = "Simple Model Serving Web App for TFBS prediction"
description = "Get TFBS predictions from the latest model. Paste a DNA sequence and click **Make Prediction**."
# Page layout: header text, a multi-line sequence box, a button, and a
# Markdown panel that receives the string returned by predict().
with gr.Blocks(theme=None) as demo:
    # Header: the page title rendered as an H1, then the usage instructions.
    for header_text in (f"# {title}", description):
        gr.Markdown(header_text)

    with gr.Row():
        seq_input = gr.Textbox(
            label="Input DNA sequence",
            placeholder="ATAGAGAC...",
            lines=6,
        )

    with gr.Row():
        predict_btn = gr.Button("Make Prediction")

    output = gr.Markdown()

    # Wire the button: the text box contents go to predict(), and its
    # Markdown reply is rendered in the output panel.
    predict_btn.click(predict, inputs=seq_input, outputs=output)


if __name__ == "__main__":
    # Honour a PORT environment variable when set; otherwise serve on 7860.
    port = int(os.getenv("PORT", "7860"))
    demo.launch(server_port=port, server_name="0.0.0.0")