|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import html

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
from wordcloud import WordCloud
import matplotlib.pyplot as plt
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Return the ``(tokenizer, model)`` pair for the ``google/gemma-2b`` checkpoint.

    The ``st.cache_resource`` decorator memoizes the result, so script reruns
    reuse the already-loaded objects instead of reloading the weights.

    Returns:
        tuple: ``(tokenizer, model)`` created with ``AutoTokenizer`` and
        ``AutoModelForCausalLM`` from the same checkpoint name.
    """
    checkpoint = "google/gemma-2b"
    # Tuple elements evaluate left to right: tokenizer first, then the model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForCausalLM.from_pretrained(checkpoint),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_text(prompt, tone, max_length, temperature=0.7, top_p=0.9, repetition_penalty=1.0):
    """Generate a sampled completion for ``prompt`` in the requested ``tone``.

    Args:
        prompt: User-supplied text the model should respond to.
        tone: ``"Funny"``, ``"Serious"`` or ``"Poetic"`` wraps ``prompt`` in a
            tone-specific instruction; any other value sends ``prompt`` unchanged.
        max_length: Target length in *words*. It is converted to a token budget
            at ~1.75 tokens per word, with a floor of 30 new tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling mass forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Returns:
        str: Only the newly generated text (the instruction prompt is sliced
        off), decoded without special tokens and whitespace-stripped.
    """
    tokenizer, model = load_model_and_tokenizer()

    tone_prompts = {
        "Funny": f"Instruction: Generate a concise, humorous response to the following prompt. Prompt: {prompt}. Use witty wordplay, unexpected twists, or lighthearted exaggeration, avoiding offensive content. Aim for a punchline-style finish.",
        "Serious": f"Instruction: Provide a detailed, thoughtful, and professional response to the following prompt. Prompt: {prompt}. Offer logical reasoning, depth, and a formal tone, as if explaining to an expert audience.",
        "Poetic": f"Instruction: Write a vivid, poetic response to the following prompt. Prompt: {prompt}. Use metaphor, rhythm, and imagery to create a lyrical flow, as if crafting a short verse.",
    }
    input_text = tone_prompts.get(tone, prompt)

    # Keep the whole encoding and move it to wherever the model lives, so the
    # attention mask is passed to generate() instead of being dropped.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]

    # Rough words -> tokens conversion; never request fewer than 30 new tokens.
    max_new_tokens = max(int(max_length * 1.75), 30)

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_return_sequences=1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation for each sequence; keep only
    # the tokens produced after the prompt.
    generated_token_ids = outputs[0, prompt_token_count:]
    return tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# App-wide stylesheet: dark background plus custom looks for the title,
# cards, inputs, buttons, sliders, expander, word-cloud card, footer and
# scrollbars. Injected once as raw HTML.
global_css = """
<style>
    /* Solid background with minimal decoration */
    .stApp {
        background-color: #1E1E3A;
        color: #f0f0f0;
    }

    /* Clean container styling for all elements */
    .stForm, div.stButton, .stMarkdown, div.stSlider {
        border-radius: 8px;
        transition: all 0.2s ease;
    }

    /* Professional title styling */
    .title {
        color: #ffffff;
        font-size: 32px;
        font-weight: 700;
        padding: 15px 0;
        font-family: 'Segoe UI', Arial, sans-serif;
        letter-spacing: 0.5px;
        margin-bottom: 15px;
        border-bottom: 2px solid #4A5BEA;
        display: inline-block;
    }

    /* Clean instructions card */
    .instructions {
        background: rgba(255, 255, 255, 0.05);
        padding: 20px;
        border-radius: 8px;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
        font-size: 16px;
        color: #f0f0f0;
        border: 1px solid rgba(255, 255, 255, 0.1);
        margin-bottom: 20px;
        line-height: 1.5;
    }

    /* Professional output box */
    .output-box {
        background: rgba(20, 20, 40, 0.8);
        padding: 20px;
        border-radius: 8px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2);
        font-family: 'Courier New', monospace;
        font-size: 16px;
        color: #ffffff;
        white-space: pre-wrap;
        border: 1px solid #4A5BEA;
        margin-top: 1.5rem;
    }

    /* Form styling */
    .stTextInput>div>div>input {
        background: rgba(255, 255, 255, 0.05);
        border: 1px solid rgba(255, 255, 255, 0.2);
        color: white;
        border-radius: 6px;
        padding: 10px;
        font-size: 16px;
    }
    .stTextInput>div>div>input:focus {
        border-color: #4A5BEA;
        box-shadow: 0 0 5px rgba(74, 91, 234, 0.5);
    }

    /* Clean, professional buttons */
    .stButton>button {
        background-color: #4A5BEA;
        color: white;
        border: none;
        border-radius: 6px;
        padding: 10px 15px;
        font-size: 16px;
        font-weight: 500;
        box-shadow: 0 2px 5px rgba(0,0,0,0.2);
        transition: all 0.2s;
        text-transform: none;
        letter-spacing: 0.5px;
    }
    .stButton>button:hover {
        background-color: #3A4AC0;
        transform: translateY(-2px);
        box-shadow: 0 4px 8px rgba(0,0,0,0.2);
    }

    /* Example button styling */
    .example-button {
        background-color: #2D3250;
        color: white;
        border: 1px solid #4A5BEA;
        border-radius: 6px;
        padding: 8px 16px;
        margin: 5px;
        cursor: pointer;
        transition: all 0.2s;
        font-size: 14px;
        font-weight: 500;
    }
    .example-button:hover {
        background-color: #4A5BEA;
    }

    /* Enhanced sliders */
    .stSlider>div>div>div {
        background: #4A5BEA !important;
    }

    /* Improved slider number values styling */
    .stSlider p {
        color: #f0f0f0 !important;
    }

    /* Fix for the blue box numbers on sliders */
    .stSlider > div > div > div > div > div > div {
        background-color: #4A5BEA !important;
        color: white !important;
        font-weight: 500 !important;
        border: none !important;
        border-radius: 4px !important;
        padding: 2px 6px !important;
        font-size: 14px !important;
    }

    /* Number value containers at the ends of sliders */
    .stSlider > div > div > div:first-child,
    .stSlider > div > div > div:last-child {
        background-color: #292952 !important;
        color: #f0f0f0 !important;
        border: 1px solid rgba(255, 255, 255, 0.1) !important;
        border-radius: 4px !important;
        padding: 4px 8px !important;
        font-size: 14px !important;
        font-family: monospace !important;
    }

    /* Custom selectbox styling */
    .stSelectbox>div>div {
        background: rgba(255, 255, 255, 0.05);
        border: 1px solid rgba(255, 255, 255, 0.2);
        border-radius: 6px;
    }

    /* Expander styling */
    .streamlit-expanderHeader {
        background: rgba(255, 255, 255, 0.05);
        border-radius: 6px;
        border: 1px solid rgba(255, 255, 255, 0.1);
        color: #f0f0f0 !important;
        font-weight: 500;
    }

    /* WordCloud container */
    .wordcloud-container {
        background: rgba(255, 255, 255, 0.05);
        border-radius: 8px;
        padding: 15px;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
        margin-top: 20px;
        border: 1px solid rgba(255, 255, 255, 0.1);
    }

    /* Footer styling */
    .footer {
        background: rgba(0, 0, 0, 0.2);
        border-radius: 8px;
        padding: 15px;
        margin-top: 30px;
        border-top: 1px solid rgba(255, 255, 255, 0.1);
        text-align: center;
    }

    /* Parameter card styling */
    .parameter-card {
        background: rgba(255, 255, 255, 0.05);
        border-radius: 8px;
        padding: 15px;
        margin-bottom: 15px;
        border: 1px solid rgba(255, 255, 255, 0.1);
    }
    .parameter-card h4 {
        color: #4A5BEA;
        margin-top: 0;
        font-size: 16px;
        font-weight: 500;
    }

    /* Scrollbar styling */
    ::-webkit-scrollbar {
        width: 8px;
        height: 8px;
    }
    ::-webkit-scrollbar-track {
        background: rgba(255, 255, 255, 0.05);
        border-radius: 4px;
    }
    ::-webkit-scrollbar-thumb {
        background: #4A5BEA;
        border-radius: 4px;
    }
    ::-webkit-scrollbar-thumb:hover {
        background: #3A4AC0;
    }
</style>
"""
st.markdown(global_css, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Page header: title on the left, GSoC logo on the right.
st.markdown("""
<div style="display: flex; align-items: center; justify-content: space-between; margin-bottom: 25px;">
    <h1 class="title">Gemma Text Generator</h1>
    <img src="https://huggingface.co/spaces/Kakaarot/Gemma-HuggingFace_TextCompletion_Demo/resolve/main/images/GSoC_logo.png" width="100" alt="GSoC 2025" style="margin-top: 5px;">
</div>
""", unsafe_allow_html=True)

# Intro card. The bold link is closed with </b>; the previous markup reopened
# <b> at that point, leaving the tag unbalanced and the rest of the card bold.
st.markdown("""
<div class="instructions">
    <b><a href = "https://huggingface.co/spaces/Kakaarot/Gemma-HuggingFace_TextCompletion_Demo/discussions/1">Please check the discussion</a></b>, I mentioned there the reason, why your first response will take little more time.
    Thanks for understanding, Now Enjoyyy 😁 <br><br>
    Enter a prompt below to generate text using the Gemma model from DeepMind. Customize the tone and length to see different outputs!<br><br>
    <b>Example:</b> Prompt: "The cat sat on" | Tone: "Funny" | Length: 50 → "The cat sat on my homework and laughed as I cried over my grades."
</div>
""", unsafe_allow_html=True)

# Collapsible explanation of how the app works.
with st.expander("How does this work?"):
    st.markdown("""
<div style="padding: 10px;">
    <ul>
        <li style="margin-bottom: 8px;">This app uses <b>Gemma-2B</b>, a language model from Google DeepMind.</li>
        <li style="margin-bottom: 8px;">You give it a prompt, and it predicts the next words one-by-one (aka causal language modeling).</li>
        <li style="margin-bottom: 8px;">The <b>tone</b> you choose adds flavor to the prompt before it hits the model.</li>
        <li style="margin-bottom: 8px;">Parameters like <b>temperature</b> control how wild or safe the answers are.</li>
        <li>The output is visualized in a <b>Word Cloud</b> so you can see which words stand out!</li>
    </ul>
</div>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
st.markdown("<p style='margin-top: 20px; margin-bottom: 10px;'>Try these examples:</p>", unsafe_allow_html=True)

example_columns = st.columns(3)

# One-shot flag: an example click sets it so the generation section runs on
# this rerun without waiting for the form's submit button.
if "trigger_example" not in st.session_state:
    st.session_state.trigger_example = False

# (button label, prompt to prefill, tone to preselect), one per column.
example_specs = (
    ("✨ Funny Cat Story", "The cat hacked my WiFi", "Funny"),
    ("🌅 Poetic Goodbye", "As the sun set on our final day", "Poetic"),
    ("🧠 Serious AI Future", "The future of AI is", "Serious"),
)

for example_column, (example_label, example_prompt, example_tone) in zip(example_columns, example_specs):
    with example_column:
        if st.button(example_label):
            # The form below reads these keys to prefill its widgets.
            st.session_state.prompt = example_prompt
            st.session_state.tone = example_tone
            st.session_state.trigger_example = True
|
|
|
|
|
|
|
|
with st.form(key="input_form"):
    st.markdown('<div style="margin-bottom: 15px;"><h3 style="color: #4A5BEA; margin-bottom: 10px;">Generate Your Text</h3></div>', unsafe_allow_html=True)

    # Prefilled from session state so the example buttons can populate the form.
    prompt = st.text_input("Enter a prompt", placeholder="e.g., 'The future of AI is'", value=st.session_state.get("prompt", ""))

    tone_options = ["Funny", "Serious", "Poetic"]
    tone_column, length_column = st.columns(2)
    with tone_column:
        tone = st.selectbox("Tone", tone_options, index=tone_options.index(st.session_state.get("tone", "Funny")))
    with length_column:
        max_length = st.slider("Word count", 20, 100, 50, help="Tries to generate text close to this word count. Output might be shorter if the model finishes early, or slightly different due to word splitting. I am considering 1.75 tokens as one word.")

    # NOTE(review): every st.markdown call renders as its own element, so this
    # opening <div> (closed by a separate call further down) probably does not
    # actually wrap the sliders in between — confirm visually.
    st.markdown('<div class="parameter-card"><h4>Advanced Parameters</h4>', unsafe_allow_html=True)

    # Style for the .value-display badges used in the results summary.
    st.markdown("""
<style>
    /* Custom number styling for slider values */
    .value-display {
        background-color: #292952;
        padding: 4px 8px;
        border-radius: 4px;
        font-family: monospace;
        border: 1px solid rgba(255, 255, 255, 0.1);
        font-size: 14px;
        display: inline-block;
        margin-left: 5px;
        color: #f0f0f0;
    }
</style>
""", unsafe_allow_html=True)

    temperature_column, top_p_column = st.columns(2)
    with temperature_column:
        temperature = st.slider("Temperature (Creativity)", 0.2, 1.5, 0.7, help="Higher values make output more random")
    with top_p_column:
        top_p = st.slider("Top-p (Nucleus Sampling)", 0.1, 1.0, 0.9, help="Controls diversity")

    repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.0, help="Higher values discourage repetition")

    st.markdown('</div>', unsafe_allow_html=True)

    submit_button = st.form_submit_button(label="Generate")
|
|
|
|
|
|
|
|
# Run generation when the form is submitted or an example button was clicked.
if submit_button or st.session_state.trigger_example:
    # Consume the one-shot example trigger so later reruns don't regenerate.
    st.session_state.trigger_example = False

    if not prompt:
        st.error("Please enter a prompt!")
    else:
        with st.spinner("Generating text..."):
            output = generate_text(prompt, tone, max_length, temperature, top_p, repetition_penalty)

        # Summary of the settings used for this generation.
        st.markdown(f"""
<div style="background: rgba(255, 255, 255, 0.05); border-radius: 8px; padding: 12px;
     margin-bottom: 15px; font-size: 14px; border: 1px solid rgba(255, 255, 255, 0.1);">
    <span style="color: #4A5BEA; font-weight: 500;">Tone:</span>
    <span class="value-display">{tone}</span> |
    <span style="color: #4A5BEA; font-weight: 500;">Temperature:</span>
    <span class="value-display">{temperature:.2f}</span> |
    <span style="color: #4A5BEA; font-weight: 500;">Words:</span>
    <span class="value-display">~{max_length}</span>
</div>
""", unsafe_allow_html=True)

        # The completion is untrusted text rendered with unsafe_allow_html:
        # escape it so <, > and & show literally instead of being parsed as
        # markup. Newlines become <br> so a blank line in the completion cannot
        # end the Markdown HTML block early and spill text outside the box.
        safe_output = html.escape(output).replace("\n", "<br>")
        st.markdown(f'<div class="output-box">{safe_output}</div>', unsafe_allow_html=True)

        st.markdown('<div class="wordcloud-container">', unsafe_allow_html=True)
        st.markdown('<h4 style="color: #4A5BEA; margin-top: 0;">Word Cloud Visualization</h4>', unsafe_allow_html=True)

        try:
            wordcloud = WordCloud(
                width=600,
                height=300,
                background_color="#1E1E3A",
                colormap="viridis",
                max_words=100,
                contour_width=0,
            ).generate(output)
        except ValueError:
            # WordCloud raises ValueError when the text yields no words (e.g. an
            # empty completion or one made only of stopwords); don't crash the page.
            st.info("Not enough words in the output to build a word cloud.")
        else:
            # Draw on an explicit figure and close it afterwards so figures don't
            # accumulate in pyplot's global state across reruns.
            fig, ax = plt.subplots(figsize=(10, 5))
            ax.imshow(wordcloud, interpolation="bilinear")
            ax.axis("off")
            st.pyplot(fig)
            plt.close(fig)

        st.markdown('</div>', unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Closing footer: credits line, the Gemma logo and a sign-off.
footer_html = """
<div class="footer">
    <p>Built with ❤️ by Utkarsh Shukla for GSoC Proposal 2025 | Powered by Gemma + Hugging Face</p>
    <p style="margin-top: 10px;">
        <img src="https://huggingface.co/spaces/Kakaarot/Gemma-HuggingFace_TextCompletion_Demo/resolve/main/images/google-gemini-icon.png" width="100" alt="Gemma by DeepMind">
    </p>
    <p style="margin-top: 10px;">Wish me luck! 🤞</p>
</div>
"""
st.markdown(footer_html, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|