# predict-tfbs / gradio_app.py
# Uploaded by moslem — commit 45e8fda ("Upload Model", verified)
# gradio_app.py
# Gradio rewrite of the original Streamlit app for TFBS prediction
import os
import gradio as gr
import numpy as np
from utils import dnaseq_features
from keras.models import load_model
# Load the Keras model once at import time so every request reuses it.
# The path can be overridden with the TFBS_MODEL_PATH environment variable.
# By default best_model.h5 is looked up next to this file rather than in the
# current working directory, so the app also starts when launched from
# another directory. Make sure best_model.h5 is present in the repo.
MODEL_PATH = os.environ.get(
    "TFBS_MODEL_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "best_model.h5"),
)
model = load_model(MODEL_PATH)
_FOUND_LABEL = "**TFBS found ✅**"
_NOT_FOUND_LABEL = "**TFBS not found ❌**"


def _format_prediction(item, pred) -> str:
    """Render one (input item, model output) pair as a Markdown list entry.

    If the model output has two or more values, it is treated as a class-score
    vector ``[p0, p1, ...]``:
      - the entry is labelled "found" when class 1 has the highest score;
      - ``p1`` is reported as the probability.

    If the output is a single value, it is treated as a score thresholded at
    0.5. Anything that cannot be read as numbers is shown verbatim.
    """
    try:
        # ravel() flattens a 0-d scalar, a (1,) row or a (k,) row into a 1-D
        # vector, so one code path handles all of them. It also avoids calling
        # float() on an ndim>0 array, which NumPy >= 1.25 deprecates.
        scores = np.asarray(pred, dtype=float).ravel()
    except (TypeError, ValueError):
        return f"- `{item}` — prediction: {pred}"

    if scores.size >= 2:
        label = _FOUND_LABEL if int(np.argmax(scores)) == 1 else _NOT_FOUND_LABEL
        return f"- `{item}` — {label} (probability: {float(scores[1]):.4f})"
    if scores.size == 1:
        score = float(scores[0])
        label = _FOUND_LABEL if score > 0.5 else _NOT_FOUND_LABEL
        return f"- `{item}` — {label} (score: {score:.4f})"
    # Empty output: there is nothing numeric to report.
    return f"- `{item}` — prediction: {pred}"


def predict(dna_seq: str) -> str:
    """Compute features for a DNA sequence, run the model and return a Markdown summary.

    Args:
        dna_seq: Text from the input box. ``None`` is treated as empty.
            Leading and trailing whitespace is stripped.

    Returns:
        Markdown text with one list entry per (item, prediction) pair.
        Instead, a short message is returned when:
          - the input is empty;
          - feature extraction or prediction raises;
          - no predictions are produced.
    """
    dna_seq = (dna_seq or "").strip()
    if not dna_seq:
        return "**Please provide an input DNA sequence.**"

    # Features come from the project's utils.dnaseq_features helper.
    # Its middle return value (ds_index) is not used by this view.
    try:
        dna_ohe_feat, _ds_index, ds_val = dnaseq_features(seq=dna_seq)
    except Exception as e:  # show helper failures in the UI instead of a traceback
        return f"Error while computing features: {e}"

    try:
        predicted = model.predict(dna_ohe_feat)
    except Exception as e:  # likewise for model failures
        return f"Error during model prediction: {e}"

    # Pair each ds_val entry with its model output row.
    # NOTE(review): zip() stops at the shorter sequence; this presumes the model
    # returns one row per ds_val entry — confirm against utils.dnaseq_features.
    lines = [_format_prediction(item, pred) for item, pred in zip(ds_val, predicted)]
    if not lines:
        return "No predictions were produced. Check the input or the model."
    return "\n".join(lines)
# ---- Gradio UI ------------------------------------------------------------
title = "Simple Model Serving Web App for TFBS prediction"
description = (
    "Get TFBS predictions from the latest model. "
    "Paste a DNA sequence and click **Make Prediction**."
)

with gr.Blocks(theme=None) as demo:
    # Page header: the title rendered as an H1, then the usage hint.
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    # Multi-line text box holding the raw DNA sequence.
    with gr.Row():
        seq_input = gr.Textbox(
            label="Input DNA sequence",
            placeholder="ATAGAGAC...",
            lines=6,
        )

    with gr.Row():
        predict_btn = gr.Button("Make Prediction")

    # predict() returns Markdown text, so its result goes into a Markdown component.
    output = gr.Markdown()

    # Clicking the button sends the text box contents to predict().
    predict_btn.click(
        fn=predict,
        inputs=seq_input,
        outputs=output,
    )
if __name__ == "__main__":
    # $PORT overrides the default 7860.
    # Binding to 0.0.0.0 listens on every network interface.
    port = int(os.getenv("PORT", "7860"))
    demo.launch(server_port=port, server_name="0.0.0.0")