# Polarium
# AI Text Assistant
# c76198f
import html
import json
from typing import Dict, List, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
# Module-level state shared by the generation/summarization helpers below.
# The compute device is chosen once at import time: CUDA when torch reports
# it as available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Hugging Face checkpoint identifiers for the two supported tasks.
TEXT_GEN_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
SUMMARIZATION_MODEL = "facebook/bart-large-cnn"

# Both tokenizer/model pairs are loaded eagerly when the module is imported,
# then moved onto the selected device.
print("Loading models...")

gen_tokenizer = AutoTokenizer.from_pretrained(TEXT_GEN_MODEL)
gen_model = AutoModelForCausalLM.from_pretrained(TEXT_GEN_MODEL)
gen_model = gen_model.to(device)

sum_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZATION_MODEL)
sum_model = sum_model.to(device)

print("Models loaded successfully!")
def count_words(text: str) -> int:
    """Return the number of whitespace-separated words in ``text``."""
    words = text.split()
    return len(words)
def generate_text_with_alternatives(
    input_text: str,
    max_tokens: int = 100
) -> Tuple[str, List[List[Dict[str, str]]]]:
    """
    Generate a chat-style reply and record the top candidates at every step.

    Args:
        input_text: User prompt; it is wrapped as a single ``user`` message
            and rendered through the tokenizer's chat template.
        max_tokens: Upper bound on the number of newly generated tokens.
            Coerced to ``int`` because UI sliders may deliver a float.

    Returns:
        ``(generated_text, token_alternatives)`` where ``token_alternatives``
        holds one list per generation step, each containing up to five
        ``{"token": str, "probability": "NN.NN%"}`` dicts ordered from most
        to least likely.
    """
    # Render the prompt with the model's chat template, then tokenize it.
    messages = [{"role": "user", "content": input_text}]
    prompt = gen_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = gen_tokenizer(prompt, return_tensors="pt").to(device)

    # output_scores + return_dict_in_generate expose the per-step scores.
    with torch.no_grad():
        outputs = gen_model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False,  # Greedy decoding
            pad_token_id=gen_tokenizer.eos_token_id
        )

    # The returned sequence starts with the prompt; keep only the new tokens.
    prompt_length = inputs.input_ids.shape[1]
    generated_ids = outputs.sequences[0][prompt_length:]
    generated_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True)

    # NOTE(review): there is one score entry per generation step, including
    # steps that emit special tokens which skip_special_tokens removes from
    # the text, so this list can be longer than the visible output.
    token_alternatives = []
    for step_scores in (getattr(outputs, "scores", None) or ()):
        probs = torch.nn.functional.softmax(step_scores[0], dim=-1)
        # Guard against a vocabulary smaller than five entries.
        top_probs, top_indices = torch.topk(probs, k=min(5, probs.shape[-1]))
        # One tolist() per tensor instead of an .item() call per element.
        token_alternatives.append([
            {
                "token": gen_tokenizer.decode([token_id]),
                "probability": f"{prob * 100:.2f}%"
            }
            for prob, token_id in zip(top_probs.tolist(), top_indices.tolist())
        ])
    return generated_text, token_alternatives
def summarize_text_with_alternatives(
    input_text: str,
    max_tokens: int = 100
) -> Tuple[str, List[List[Dict[str, str]]]]:
    """
    Summarize text and record the top candidates at every decoding step.

    Args:
        input_text: Text to summarize; tokenized with truncation at 1024
            tokens.
        max_tokens: Passed to ``generate`` as ``max_length``. Coerced to
            ``int`` because UI sliders may deliver a float.

    Returns:
        ``(summary_text, token_alternatives)`` where ``token_alternatives``
        holds one list per decoding step, each containing up to five
        ``{"token": str, "probability": "NN.NN%"}`` dicts ordered from most
        to least likely.
    """
    inputs = sum_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=1024,
        truncation=True
    ).to(device)

    with torch.no_grad():
        outputs = sum_model.generate(
            **inputs,
            max_length=int(max_tokens),
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False,
            # do_sample=False alone does not make decoding greedy: a
            # checkpoint's generation config may set num_beams > 1. Under
            # beam search each score tensor has one row per beam, so reading
            # row 0 below would not follow the returned sequence. Forcing a
            # single beam makes row 0 the row that produced each token.
            num_beams=1,
        )

    summary_text = sum_tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    # NOTE(review): there is one score entry per decoding step, including
    # steps that emit special tokens which skip_special_tokens removes from
    # the text, so this list can be longer than the visible summary.
    token_alternatives = []
    for step_scores in (getattr(outputs, "scores", None) or ()):
        probs = torch.nn.functional.softmax(step_scores[0], dim=-1)
        # Guard against a vocabulary smaller than five entries.
        top_probs, top_indices = torch.topk(probs, k=min(5, probs.shape[-1]))
        # One tolist() per tensor instead of an .item() call per element.
        token_alternatives.append([
            {
                "token": sum_tokenizer.decode([token_id]),
                "probability": f"{prob * 100:.2f}%"
            }
            for prob, token_id in zip(top_probs.tolist(), top_indices.tolist())
        ])
    return summary_text, token_alternatives
# Stylesheet plus the opening tag of the result container. Emitted once at the
# start of every tooltip-enabled result.
_TOOLTIP_HTML_HEAD = """
<style>
.word-container {
display: inline-block;
position: relative;
margin: 2px;
padding: 2px 4px;
cursor: pointer;
border-radius: 3px;
transition: background-color 0.2s;
}
.word-container:hover {
background-color: #e3f2fd;
}
.tooltip {
visibility: hidden;
position: absolute;
z-index: 1000;
background-color: #263238;
color: white;
padding: 12px;
border-radius: 6px;
font-size: 13px;
min-width: 250px;
bottom: 125%;
left: 50%;
transform: translateX(-50%);
box-shadow: 0 4px 6px rgba(0,0,0,0.3);
opacity: 0;
transition: opacity 0.3s;
}
.tooltip::after {
content: "";
position: absolute;
top: 100%;
left: 50%;
margin-left: -5px;
border-width: 5px;
border-style: solid;
border-color: #263238 transparent transparent transparent;
}
.word-container:hover .tooltip {
visibility: visible;
opacity: 1;
}
.alternative-item {
padding: 4px 0;
border-bottom: 1px solid #37474f;
}
.alternative-item:last-child {
border-bottom: none;
}
.token-text {
font-weight: bold;
color: #81d4fa;
}
.probability {
float: right;
color: #a5d6a7;
}
.result-container {
padding: 20px;
font-size: 16px;
line-height: 1.8;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
}
</style>
<div class='result-container'>
"""


def _render_tooltip(alternatives: List[Dict]) -> str:
    """Build the hidden tooltip ``<div>`` listing one step's ranked alternatives."""
    rows = [
        "<div class='alternative-item'>"
        f"<span>{rank}. <span class='token-text'>{html.escape(str(alt['token']))}</span></span>"
        f"<span class='probability'>{html.escape(str(alt['probability']))}</span>"
        "</div>"
        for rank, alt in enumerate(alternatives, 1)
    ]
    return (
        "<div class='tooltip'>"
        "<div style='margin-bottom: 8px; font-weight: bold; border-bottom: 2px solid #37474f; padding-bottom: 4px;'>Top 5 Alternatives:</div>"
        + "".join(rows)
        + "</div>"
    )


def create_html_with_tooltips(text: str, token_alternatives: List[Dict]) -> str:
    """
    Create HTML with hoverable words that show token alternatives.

    Args:
        text: Model output to display. It is split on whitespace and each
            word becomes one hoverable ``<span>``.
        token_alternatives: One entry per generation step; each entry is a
            list of ``{"token": ..., "probability": ...}`` dicts.

    Returns:
        An HTML fragment. With no alternatives it is a plain ``<div>`` around
        the text; otherwise it is the stylesheet followed by one span per
        word, with a tooltip on the first ``len(token_alternatives)`` words.

    All model-derived text (words and candidate tokens) is HTML-escaped so
    characters such as ``<`` or ``&`` render literally instead of being
    parsed as markup.
    """
    if not token_alternatives:
        return f"<div style='padding: 20px; font-size: 16px;'>{html.escape(text)}</div>"

    html_parts = [_TOOLTIP_HTML_HEAD]

    # Approximate mapping: the i-th whitespace-separated word is paired with
    # the i-th generation step. Words beyond the recorded steps get no tooltip.
    for position, word in enumerate(text.split()):
        safe_word = html.escape(word)
        if position < len(token_alternatives):
            tooltip_html = _render_tooltip(token_alternatives[position])
            html_parts.append(f"<span class='word-container'>{safe_word}{tooltip_html}</span>")
        else:
            html_parts.append(f"<span class='word-container'>{safe_word}</span>")

    html_parts.append("</div>")
    return "".join(html_parts)
def process_text(input_text: str, mode: str, max_tokens: int) -> Tuple[str, str]:
    """
    Main processing function that handles both text generation and summarization.

    Args:
        input_text: Raw text from the UI.
        mode: ``"Text Generation"`` selects generation; any other value
            selects summarization.
        max_tokens: Token budget forwarded to the selected model helper.
            Coerced to ``int`` because UI sliders may deliver a float.

    Returns:
        ``(result_html, status_message)``. Validation failures and runtime
        errors come back as a red error ``<div>`` plus a status string that
        starts with a cross mark; nothing is raised to the caller.
    """
    if not input_text or not input_text.strip():
        return (
            "<div style='padding: 20px; color: red;'>Please enter some text to process.</div>",
            "❌ No input provided",
        )

    # Enforce the 500-word input limit before invoking a model.
    word_count = count_words(input_text)
    if word_count > 500:
        return (
            f"<div style='padding: 20px; color: red;'>Input exceeds maximum limit of 500 words. Current: {word_count} words.</div>",
            f"❌ Input too long ({word_count} words)",
        )

    if mode == "Text Generation":
        run_model = generate_text_with_alternatives
    else:  # Text Summarization
        run_model = summarize_text_with_alternatives

    # UI boundary: any failure is turned into a visible message instead of
    # propagating out of the event handler.
    try:
        output_text, alternatives = run_model(input_text, int(max_tokens))
        result_html = create_html_with_tooltips(output_text, alternatives)
    except Exception as e:
        detail = str(e)
        # Exception text can contain markup-significant characters; escape it
        # for the HTML pane. The status box is plain text and is left as-is.
        error_msg = f"<div style='padding: 20px; color: red;'>Error: {html.escape(detail)}</div>"
        return error_msg, f"❌ Error: {detail}"

    return result_html, f"✅ Generated {len(alternatives)} tokens"
# Markdown shown at the top of the page.
INTRO_MARKDOWN = """
# 🤖 AI Text Assistant
Generate text or summarize articles using state-of-the-art AI models.
**Hover over any word** in the result to see the top 5 alternative tokens the AI considered!
"""

# Markdown shown below the result pane.
TIPS_MARKDOWN = """
### 💡 Tips:
- **Text Generation**: Provide a prompt and the AI will continue writing
- **Text Summarization**: Paste an article or long text to get a concise summary
- **Hover** over any word in the output to see what other words the AI considered
- Models used: Qwen/Qwen2.5-0.5B-Instruct (generation) & facebook/bart-large-cnn (summarization)
"""

# Create Gradio interface: controls, then the HTML result pane, then the tips.
with gr.Blocks(title="AI Text Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(
                choices=["Text Generation", "Text Summarization"],
                value="Text Generation",
                label="Mode",
                info="Choose between generating new text or summarizing existing text"
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter your text here... (max 500 words)",
                lines=6,
                max_lines=10
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                    label="Max Tokens",
                    info="Maximum number of tokens to generate"
                )
            process_btn = gr.Button("🚀 Process", variant="primary", size="lg")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        output_html = gr.HTML(label="Result")

    gr.Markdown(TIPS_MARKDOWN)

    # Connect the button to the processing function; process_text returns
    # (result_html, status_message) in the same order as ``outputs``.
    process_btn.click(
        fn=process_text,
        inputs=[input_text, mode, max_tokens],
        outputs=[output_html, status]
    )

if __name__ == "__main__":
    demo.launch()