"""Gradio app: text generation / summarization with per-token alternative tooltips."""

import html
import json
import traceback
from typing import Dict, List, Optional, Sequence, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

# Global variables for models
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model names
TEXT_GEN_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
SUMMARIZATION_MODEL = "facebook/bart-large-cnn"

# Maximum number of whitespace-separated words accepted as input.
MAX_INPUT_WORDS = 500

# Number of candidate tokens recorded for each generation step.
TOP_K_ALTERNATIVES = 5

# Load models and tokenizers
print("Loading models...")
gen_tokenizer = AutoTokenizer.from_pretrained(TEXT_GEN_MODEL)
gen_model = AutoModelForCausalLM.from_pretrained(TEXT_GEN_MODEL).to(device)
sum_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZATION_MODEL).to(device)
print("Models loaded successfully!")

# One entry per generation step; each entry is a list of
# {"token": str, "probability": "12.34%"} dicts, most likely first.
TokenAlternatives = List[List[Dict[str, str]]]

# Styles for the hoverable output. Scoped under .token-container so they do
# not leak into the rest of the Gradio page.
_TOOLTIP_CSS = """<style>
.token-container { font-size: 16px; line-height: 2; white-space: pre-wrap; padding: 12px; }
.token-container .token { position: relative; cursor: pointer; border-radius: 3px; }
.token-container .token:hover { background-color: rgba(99, 102, 241, 0.2); }
.token-container .tooltip {
    visibility: hidden; opacity: 0; position: absolute; bottom: 125%; left: 0; z-index: 10;
    min-width: 180px; padding: 8px 10px; border-radius: 6px;
    background: #1f2937; color: #f9fafb; font-size: 13px; line-height: 1.4;
    white-space: nowrap; box-shadow: 0 4px 12px rgba(0, 0, 0, 0.25);
    transition: opacity 0.15s;
}
.token-container .token:hover .tooltip { visibility: visible; opacity: 1; }
.token-container .tooltip-title { display: block; font-weight: bold; margin-bottom: 4px; }
.token-container .alt-row { display: flex; justify-content: space-between; gap: 12px; }
.token-container .alt-token { white-space: pre; }
</style>"""


def count_words(text: str) -> int:
    """Return the number of whitespace-separated words in *text*."""
    return len(text.split())


def _extract_token_alternatives(
    tokenizer, outputs, top_k: int = TOP_K_ALTERNATIVES
) -> Tuple[List[str], TokenAlternatives]:
    """Collect the emitted token and its top-k candidates for every generation step.

    Args:
        tokenizer: Tokenizer used to decode token ids.
        outputs: Result of ``model.generate(..., output_scores=True,
            return_dict_in_generate=True)`` for a batch of one, produced with
            greedy decoding (so ``scores[i][0]`` belongs to the emitted sequence).
        top_k: Number of candidates to record per step.

    Returns:
        ``(display_tokens, token_alternatives)``, both with one entry per step.
        ``display_tokens[i]`` is the decoded text of the token actually emitted
        at step ``i`` (``""`` for special tokens such as BOS/EOS).
        ``token_alternatives[i]`` holds the top-k candidates at that step.
    """
    scores = getattr(outputs, "scores", None)
    if not scores:
        return [], []

    # The last len(scores) ids of the sequence are the tokens chosen at each
    # step; this holds for decoder-only (prompt prefix) and encoder-decoder
    # (decoder-start prefix) models alike.
    chosen_ids = outputs.sequences[0, -len(scores):].tolist()
    special_ids = set(tokenizer.all_special_ids)

    display_tokens: List[str] = []
    token_alternatives: TokenAlternatives = []
    for chosen_id, step_scores in zip(chosen_ids, scores):
        # .float() keeps the softmax stable should the model run in half precision.
        probs = torch.softmax(step_scores[0].float(), dim=-1)
        top_probs, top_ids = torch.topk(probs, k=min(top_k, probs.shape[-1]))
        token_alternatives.append([
            {"token": tokenizer.decode([token_id]), "probability": f"{prob * 100:.2f}%"}
            for prob, token_id in zip(top_probs.tolist(), top_ids.tolist())
        ])
        # NOTE: a token holding only part of a multi-byte character decodes to
        # a replacement character on its own; the full text is unaffected.
        display_tokens.append("" if chosen_id in special_ids else tokenizer.decode([chosen_id]))
    return display_tokens, token_alternatives


def _generate_text(input_text: str, max_tokens: int) -> Tuple[str, List[str], TokenAlternatives]:
    """Greedily continue a chat prompt; return (text, display_tokens, alternatives)."""
    messages = [{"role": "user", "content": input_text}]
    prompt = gen_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = gen_tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = gen_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False,  # Greedy decoding
            pad_token_id=gen_tokenizer.eos_token_id,
        )

    # Drop the prompt so only newly generated tokens are decoded.
    generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
    text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True)
    tokens, alternatives = _extract_token_alternatives(gen_tokenizer, outputs)
    return text, tokens, alternatives


def _summarize_text(input_text: str, max_tokens: int) -> Tuple[str, List[str], TokenAlternatives]:
    """Greedily summarize *input_text*; return (summary, display_tokens, alternatives)."""
    inputs = sum_tokenizer(
        input_text, return_tensors="pt", max_length=1024, truncation=True
    ).to(device)

    # The checkpoint's generation config may enable beam search (bart-large-cnn
    # publishes num_beams=4). Under beam search scores[i][0] is beam 0's row,
    # not necessarily the emitted token's, so force true greedy decoding.
    # Clamp the configured min_length so small "Max Tokens" values stay feasible.
    min_length = min(sum_model.generation_config.min_length or 0, max_tokens)
    with torch.no_grad():
        outputs = sum_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            min_length=min_length,
            num_beams=1,
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False,  # Greedy decoding
        )

    summary = sum_tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    tokens, alternatives = _extract_token_alternatives(sum_tokenizer, outputs)
    return summary, tokens, alternatives


def generate_text_with_alternatives(
    input_text: str, max_tokens: int = 100
) -> Tuple[str, TokenAlternatives]:
    """
    Generate text and capture top-5 alternative tokens for each generated token.

    Returns: (generated_text, token_alternatives)
    """
    text, _tokens, alternatives = _generate_text(input_text, max_tokens)
    return text, alternatives


def summarize_text_with_alternatives(
    input_text: str, max_tokens: int = 100
) -> Tuple[str, TokenAlternatives]:
    """
    Summarize text and capture top-5 alternative tokens for each generated token.

    Returns: (summary_text, token_alternatives)
    """
    summary, _tokens, alternatives = _summarize_text(input_text, max_tokens)
    return summary, alternatives


def _format_alt_token(token: str) -> str:
    """HTML-escape a candidate token, showing newlines as a visible ``\\n``."""
    return html.escape(token.replace("\n", "\\n"))


def _tooltip_html(alternatives: List[Dict[str, str]]) -> str:
    """Build the hidden tooltip listing one step's candidate tokens."""
    rows = "".join(
        "<span class='alt-row'>"
        f"<span class='alt-token'>{rank}. {_format_alt_token(alt['token'])}</span>"
        f"<span class='alt-prob'>{html.escape(alt['probability'])}</span>"
        "</span>"
        for rank, alt in enumerate(alternatives, 1)
    )
    return (
        "<span class='tooltip'>"
        f"<span class='tooltip-title'>Top {len(alternatives)} Alternatives:</span>"
        f"{rows}</span>"
    )


def create_html_with_tooltips(
    text: str,
    token_alternatives: TokenAlternatives,
    tokens: Optional[Sequence[str]] = None,
) -> str:
    """
    Create HTML with hoverable spans that show token alternatives.

    Args:
        text: The fully decoded output text.
        token_alternatives: Per-step candidate lists (see ``TokenAlternatives``).
        tokens: Optional decoded text of the emitted token at each step, aligned
            with ``token_alternatives``. When given, each token is its own
            hoverable span, so tooltips match exactly. Empty strings (special
            tokens) are skipped. When omitted, the i-th whitespace word is paired
            with the i-th step, which is only an approximation.

    All model-produced text is HTML-escaped before insertion.
    """
    if not token_alternatives:
        return f"<div style='padding: 12px; white-space: pre-wrap;'>{html.escape(text)}</div>"

    parts = [_TOOLTIP_CSS, "<div class='token-container'>"]
    if tokens is not None:
        for token, alternatives in zip(tokens, token_alternatives):
            if not token:
                continue  # special tokens (BOS/EOS) have no visible text
            parts.append(
                f"<span class='token'>{html.escape(token)}{_tooltip_html(alternatives)}</span>"
            )
    else:
        # Approximate mapping: words and tokens do not correspond one-to-one.
        for index, word in enumerate(text.split()):
            if index < len(token_alternatives):
                parts.append(
                    f"<span class='token'>{html.escape(word)}"
                    f"{_tooltip_html(token_alternatives[index])}</span> "
                )
            else:
                parts.append(f"<span>{html.escape(word)}</span> ")
    parts.append("</div>")
    return "".join(parts)


def _message_html(message: str) -> str:
    """Wrap a user-facing error/validation *message* in escaped HTML."""
    return f"<div style='padding: 12px; color: #b91c1c;'>{html.escape(message)}</div>"


def process_text(input_text: str, mode: str, max_tokens: int) -> Tuple[str, str]:
    """
    Main processing function that handles both text generation and summarization.

    Returns: (result_html, status_message)
    """
    if not input_text or not input_text.strip():
        return _message_html("Please enter some text to process."), "❌ No input provided"

    # Check word count
    word_count = count_words(input_text)
    if word_count > MAX_INPUT_WORDS:
        return (
            _message_html(
                f"Input exceeds maximum limit of {MAX_INPUT_WORDS} words. "
                f"Current: {word_count} words."
            ),
            f"❌ Input too long ({word_count} words)",
        )

    # Gradio sliders may deliver floats; generate() expects an int budget.
    max_tokens = int(max_tokens)
    runner = _generate_text if mode == "Text Generation" else _summarize_text

    try:
        text, tokens, alternatives = runner(input_text, max_tokens)
        result_html = create_html_with_tooltips(text, alternatives, tokens)
    except Exception as e:  # UI boundary: report any model/runtime failure to the user
        traceback.print_exc()
        return _message_html(f"Error: {e}"), f"❌ Error: {e}"
    return result_html, f"✅ Generated {len(alternatives)} tokens"


# Create Gradio interface
with gr.Blocks(title="AI Text Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 AI Text Assistant

    Generate text or summarize articles using state-of-the-art AI models.
    **Hover over any word** in the result to see the top 5 alternative tokens the AI considered!
    """)

    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(
                choices=["Text Generation", "Text Summarization"],
                value="Text Generation",
                label="Mode",
                info="Choose between generating new text or summarizing existing text",
            )

            input_text = gr.Textbox(
                label="Input Text",
                placeholder=f"Enter your text here... (max {MAX_INPUT_WORDS} words)",
                lines=6,
                max_lines=10,
            )

            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                    label="Max Tokens",
                    info="Maximum number of tokens to generate",
                )

            process_btn = gr.Button("🚀 Process", variant="primary", size="lg")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        output_html = gr.HTML(label="Result")

    gr.Markdown(f"""
    ### 💡 Tips:
    - **Text Generation**: Provide a prompt and the AI will continue writing
    - **Text Summarization**: Paste an article or long text to get a concise summary
    - **Hover** over any word in the output to see what other words the AI considered
    - Models used: {TEXT_GEN_MODEL} (generation) & {SUMMARIZATION_MODEL} (summarization)
    """)

    # Connect the button to the processing function
    process_btn.click(
        fn=process_text,
        inputs=[input_text, mode, max_tokens],
        outputs=[output_html, status],
    )

if __name__ == "__main__":
    demo.launch()