Spaces:
Sleeping
Sleeping
| # gradio_app.py | |
| # Gradio rewrite of the original Streamlit app for TFBS prediction | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| from utils import dnaseq_features | |
| from keras.models import load_model | |
# Load the model once at startup. The default path is best_model.h5 located next
# to this file (rather than in the current working directory), so the app also
# starts when launched from another directory; set MODEL_PATH to override it.
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "best_model.h5"),
)
model = load_model(MODEL_PATH)
def _format_prediction(item: object, pred: object) -> str:
    """Render one (featurizer entry, model output row) pair as a Markdown bullet."""
    try:
        # Flatten so a (1,)-shaped sigmoid row and a bare scalar share one path;
        # calling float() on a 1-element ndarray is deprecated since NumPy 1.25.
        values = np.asarray(pred, dtype=float).ravel()
        if values.size >= 2:
            # Vector output such as [p0, p1]: argmax == 1 is treated as "found",
            # and the reported probability is that of class 1.
            prob = float(values[1])
            label = (
                "**TFBS found ✅**" if np.argmax(values) == 1 else "**TFBS not found ❌**"
            )
            return f"- `{item}` → {label} (probability: {prob:.4f})"
        if values.size == 1:
            # Single score output: thresholded at 0.5.
            score = float(values[0])
            label = "**TFBS found ✅**" if score > 0.5 else "**TFBS not found ❌**"
            return f"- `{item}` → {label} (score: {score:.4f})"
    except (TypeError, ValueError):
        # The prediction is not numeric; fall through to the raw display below.
        pass
    # Best-effort fallback for empty or non-numeric predictions.
    return f"- `{item}` → prediction: {pred}"


def predict(dna_seq: str) -> str:
    """Predict TFBS presence for a DNA sequence and return a Markdown summary.

    The sequence is featurized with ``dnaseq_features`` and scored with the
    module-level ``model``. Each element of the featurizer's ``ds_val`` output
    (presumably the sequence windows that were scored; confirm against
    ``utils.dnaseq_features``) is paired with the matching model output row and
    rendered as one Markdown bullet.

    Args:
        dna_seq: Raw text from the input box. ``None`` is treated as empty, and
            leading/trailing whitespace is stripped.

    Returns:
        A Markdown string: a bullet list of per-entry predictions, or a short
        message explaining why no predictions were produced. Errors from the
        featurizer or the model are reported in the returned text rather than
        raised, so the UI always has something to display.
    """
    dna_seq = (dna_seq or "").strip()
    if not dna_seq:
        return "**Please provide an input DNA sequence.**"

    # Compute features using the project's util function.
    try:
        dna_ohe_feat, _ds_index, ds_val = dnaseq_features(seq=dna_seq)
    except Exception as e:
        return f"Error while computing features: {e}"

    try:
        predicted = model.predict(dna_ohe_feat)
    except Exception as e:
        return f"Error during model prediction: {e}"

    lines = [_format_prediction(item, row) for item, row in zip(ds_val, predicted)]
    if not lines:
        return "No predictions were produced. Check the input or the model."
    return "\n".join(lines)
# Gradio UI
title = "Simple Model Serving Web App for TFBS prediction"
description = "Get TFBS predictions from the latest model. Paste a DNA sequence and click **Make Prediction**."

# Passing ``title`` to Blocks also sets the browser-tab title, which would
# otherwise stay at Gradio's default; theme is left at its default.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        seq_input = gr.Textbox(lines=6, placeholder="ATAGAGAC...", label="Input DNA sequence")
    with gr.Row():
        predict_btn = gr.Button("Make Prediction")
    # Results are rendered as Markdown because predict() returns a Markdown list.
    output = gr.Markdown()
    predict_btn.click(fn=predict, inputs=seq_input, outputs=output)
if __name__ == "__main__":
    # Listen on every interface (needed inside containers); the hosting
    # platform may provide the port via $PORT, otherwise use Gradio's 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )