# Streamlit-chatbot — src/streamlit_app.py
# (Hugging Face Space by Maitreyee22; commit 23ed344 "Update src/streamlit_app.py")
import os
# # 1. Create the folder if it doesn't exist
# cache_root = "/mnt/data/huggingface"
# os.makedirs(cache_root, exist_ok=True)
# # 2. Redirect Hugging Face Hub, Transformers & Datasets caches
# os.environ["HF_HOME"] = cache_root
# os.environ["TRANSFORMERS_CACHE"] = os.path.join(cache_root, "transformers")
# os.environ["HF_DATASETS_CACHE"] = os.path.join(cache_root, "datasets")
import streamlit as st
# Import the high-level pipeline API from Hugging Face Transformers.
# It simplifies loading models/tokenizers and running common tasks.
from transformers import pipeline
#hf_token = st.secrets["HF_HUB_TOKEN"]
# Read the Hub token from the process environment (set by the host/Space).
# May be None if unset; then only public models can be downloaded.
hf_token= os.getenv("HUGGING_FACE_HUB_TOKEN")
# 1. Cache the pipeline so it loads only once per server process.
@st.cache_resource
def get_generator():
    """Load and return a cached FLAN-T5 text2text-generation pipeline.

    The pipeline object wraps both the tokenizer and the model.
    Returns:
        transformers.Pipeline: a seq2seq generation pipeline for
        "google/flan-t5-small".
    """
    # "text2text-generation" selects a seq2seq (T5-family) task.
    # `token=` replaces the deprecated `use_auth_token=` argument; passing
    # None is fine for public models such as flan-t5-small.
    return pipeline(
        "text2text-generation",
        model="google/flan-t5-small",
        token=hf_token,
    )
# Build (or fetch the cached) generation pipeline once at app start.
generator = get_generator()
st.title("📝 FLAN-T5 Text-to-Text Generator")
st.write("Enter a prompt below and hit Generate to see the model’s output.")
# 2. Prompt the user for input
user_input = st.text_area("Your prompt:", height=120)
# 3. Generation settings in the sidebar
with st.sidebar:
st.header("Generation Settings")
max_length = st.slider("Max output length", min_value=16, max_value=200, value=50)
num_beams = st.slider("Beam search width", min_value=1, max_value=8, value=4)
do_sample = st.checkbox("Enable sampling", value=False)
top_k = st.slider("Top-k sampling", min_value=0, max_value=100, value=50)
temperature= st.slider("Temperature", min_value=0.1, max_value=2.0, value=1.0, step=0.1)
# 4. Generate button
if st.button("🔄 Generate"):
    if not user_input.strip():
        st.error("Please enter a prompt first.")
    else:
        with st.spinner("Generating…"):
            # BUG FIX: forward the sidebar settings to the pipeline.
            # Previously the call was `generator(user_input)`, which silently
            # ignored every setting the user chose in the sidebar.
            gen_kwargs = {
                "max_length": max_length,
                "num_beams": num_beams,
                "do_sample": do_sample,
            }
            # top_k / temperature only apply when sampling is enabled;
            # passing them otherwise triggers transformers warnings.
            if do_sample:
                gen_kwargs["top_k"] = top_k
                gen_kwargs["temperature"] = temperature
            outputs = generator(user_input, **gen_kwargs)
        # pipeline returns a list of dicts with key "generated_text"
        result = outputs[0]["generated_text"]
        st.subheader("Output")
        st.write(result)