Spaces:
Sleeping
Sleeping
import html
import json
from typing import Dict, List, Tuple

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
# Module-level state shared by the request handlers below.
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Hugging Face checkpoint identifiers, one per supported task.
TEXT_GEN_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
SUMMARIZATION_MODEL = "facebook/bart-large-cnn"

# Each tokenizer/model pair is loaded once at import time and moved to the
# selected device so every request reuses the same instances.
print("Loading models...")

gen_tokenizer = AutoTokenizer.from_pretrained(TEXT_GEN_MODEL)
gen_model = AutoModelForCausalLM.from_pretrained(TEXT_GEN_MODEL)
gen_model = gen_model.to(device)

sum_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZATION_MODEL)
sum_model = sum_model.to(device)

print("Models loaded successfully!")
def count_words(text: str) -> int:
    """Return the number of whitespace-separated words in ``text``."""
    words = text.split()
    return len(words)
def generate_text_with_alternatives(
    input_text: str,
    max_tokens: int = 100
) -> Tuple[str, List[Dict]]:
    """
    Answer ``input_text`` with the chat model and record, for every decoding
    step, the five highest-probability candidate tokens.

    Returns: (generated_text, token_alternatives). ``token_alternatives`` holds
    one entry per decoding step; each entry is a list of
    ``{"token": str, "probability": "NN.NN%"}`` dicts ordered from most to
    least likely.
    """
    # Wrap the input as a single user turn and render it with the chat template
    chat = [{"role": "user", "content": input_text}]
    prompt = gen_tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True
    )
    encoded = gen_tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = encoded.input_ids.shape[1]

    # Ask generate() to return the per-step scores alongside the sequences
    with torch.no_grad():
        result = gen_model.generate(
            **encoded,
            max_new_tokens=max_tokens,
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False,  # Greedy decoding
            pad_token_id=gen_tokenizer.eos_token_id
        )

    # The returned sequence starts with the prompt; keep only what was generated
    new_token_ids = result.sequences[0][prompt_length:]
    generated_text = gen_tokenizer.decode(new_token_ids, skip_special_tokens=True)

    token_alternatives = []
    step_scores = getattr(result, "scores", None)
    if step_scores:
        for step in step_scores:
            # Row 0 is the only sequence in the batch
            distribution = torch.softmax(step[0], dim=-1)
            best_probs, best_ids = torch.topk(distribution, k=5)
            token_alternatives.append([
                {
                    "token": gen_tokenizer.decode([token_id]),
                    "probability": f"{probability * 100:.2f}%"
                }
                for probability, token_id in zip(best_probs.tolist(), best_ids.tolist())
            ])
    return generated_text, token_alternatives
def summarize_text_with_alternatives(
    input_text: str,
    max_tokens: int = 100
) -> Tuple[str, List[List[Dict]]]:
    """
    Summarize ``input_text`` and record, for every decoding step, the five
    highest-probability candidate tokens.

    Args:
        input_text: Text to summarize; it is truncated to 1024 tokens.
        max_tokens: Maximum number of new tokens to generate.

    Returns: (summary_text, token_alternatives). ``token_alternatives`` holds
    one entry per decoding step; each entry is a list of
    ``{"token": str, "probability": "NN.NN%"}`` dicts ordered from most to
    least likely.
    """
    inputs = sum_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    ).to(device)

    with torch.no_grad():
        outputs = sum_model.generate(
            **inputs,
            # Count generated tokens only, matching the text-generation path.
            max_new_tokens=max_tokens,
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False,
            # The checkpoint's generation defaults enable beam search. With
            # several beams, row 0 of each score tensor is just one beam and
            # need not belong to the returned sequence, so force true greedy
            # decoding to keep scores and output aligned.
            num_beams=1,
        )

    summary_text = sum_tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    token_alternatives = []
    step_scores = getattr(outputs, "scores", None)
    if step_scores:
        for score_tensor in step_scores:
            # Row 0 is the only sequence in the batch
            probs = torch.nn.functional.softmax(score_tensor[0], dim=-1)
            top_probs, top_indices = torch.topk(probs, k=5)
            token_alternatives.append([
                {
                    "token": sum_tokenizer.decode([token_id]),
                    "probability": f"{prob * 100:.2f}%",
                }
                for prob, token_id in zip(top_probs.tolist(), top_indices.tolist())
            ])
    return summary_text, token_alternatives
# Stylesheet plus the opening tag of the result container; emitted once at the
# start of every rendered result that has tooltips.
_TOOLTIP_STYLE = """
<style>
    .word-container {
        display: inline-block;
        position: relative;
        margin: 2px;
        padding: 2px 4px;
        cursor: pointer;
        border-radius: 3px;
        transition: background-color 0.2s;
    }
    .word-container:hover {
        background-color: #e3f2fd;
    }
    .tooltip {
        visibility: hidden;
        position: absolute;
        z-index: 1000;
        background-color: #263238;
        color: white;
        padding: 12px;
        border-radius: 6px;
        font-size: 13px;
        min-width: 250px;
        bottom: 125%;
        left: 50%;
        transform: translateX(-50%);
        box-shadow: 0 4px 6px rgba(0,0,0,0.3);
        opacity: 0;
        transition: opacity 0.3s;
    }
    .tooltip::after {
        content: "";
        position: absolute;
        top: 100%;
        left: 50%;
        margin-left: -5px;
        border-width: 5px;
        border-style: solid;
        border-color: #263238 transparent transparent transparent;
    }
    .word-container:hover .tooltip {
        visibility: visible;
        opacity: 1;
    }
    .alternative-item {
        padding: 4px 0;
        border-bottom: 1px solid #37474f;
    }
    .alternative-item:last-child {
        border-bottom: none;
    }
    .token-text {
        font-weight: bold;
        color: #81d4fa;
    }
    .probability {
        float: right;
        color: #a5d6a7;
    }
    .result-container {
        padding: 20px;
        font-size: 16px;
        line-height: 1.8;
        font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
    }
</style>
<div class='result-container'>
"""


def _build_tooltip(alternatives: List[Dict]) -> str:
    """Render the tooltip listing the candidate tokens for one position."""
    rows = "".join(
        "<div class='alternative-item'>"
        f"<span>{rank}. <span class='token-text'>{html.escape(str(alt['token']))}</span></span>"
        f"<span class='probability'>{html.escape(str(alt['probability']))}</span>"
        "</div>"
        for rank, alt in enumerate(alternatives, 1)
    )
    return (
        "<div class='tooltip'>"
        "<div style='margin-bottom: 8px; font-weight: bold; border-bottom: 2px solid #37474f; padding-bottom: 4px;'>Top 5 Alternatives:</div>"
        f"{rows}"
        "</div>"
    )


def create_html_with_tooltips(text: str, token_alternatives: List[List[Dict]]) -> str:
    """
    Create HTML with hoverable words that show token alternatives.

    Args:
        text: The model output to display. It is split on whitespace and each
            word becomes one hoverable span.
        token_alternatives: One list of ``{"token", "probability"}`` dicts per
            decoding step. The i-th word is paired with the i-th entry, which
            is only an approximate word-to-token mapping; words beyond the
            last entry are rendered without a tooltip.

    Returns:
        An HTML fragment. When ``token_alternatives`` is empty the text is
        returned in a plain ``<div>`` with no tooltips.

    All interpolated values are HTML-escaped: model output and token strings
    can contain markup-like text (for example ``<|im_end|>``) that the browser
    would otherwise interpret as tags.
    """
    if not token_alternatives:
        return f"<div style='padding: 20px; font-size: 16px;'>{html.escape(text)}</div>"

    html_parts = [_TOOLTIP_STYLE]
    for position, word in enumerate(text.split()):
        safe_word = html.escape(word)
        if position < len(token_alternatives):
            tooltip_html = _build_tooltip(token_alternatives[position])
            html_parts.append(f"<span class='word-container'>{safe_word}{tooltip_html}</span>")
        else:
            html_parts.append(f"<span class='word-container'>{safe_word}</span>")
    html_parts.append("</div>")
    return "".join(html_parts)
def process_text(input_text: str, mode: str, max_tokens: int) -> Tuple[str, str]:
    """
    Main processing function that handles both text generation and summarization.

    Args:
        input_text: Text typed by the user; rejected when blank or longer
            than 500 words.
        mode: "Text Generation" selects the generation model; any other value
            is treated as summarization.
        max_tokens: Token budget forwarded to the selected model.

    Returns: (result_html, status_message). Validation problems and model
    errors are reported through the same pair rather than raised, so the UI
    always has something to display.
    """
    if not input_text or not input_text.strip():
        return (
            "<div style='padding: 20px; color: red;'>Please enter some text to process.</div>",
            "❌ No input provided",
        )

    # Check word count
    word_count = count_words(input_text)
    if word_count > 500:
        return (
            f"<div style='padding: 20px; color: red;'>Input exceeds maximum limit of 500 words. Current: {word_count} words.</div>",
            f"❌ Input too long ({word_count} words)",
        )

    if mode == "Text Generation":
        run_model = generate_text_with_alternatives
    else:  # Text Summarization
        run_model = summarize_text_with_alternatives

    try:
        # Slider values can arrive as floats; generation lengths must be ints.
        result_text, alternatives = run_model(input_text, int(max_tokens))
        result_html = create_html_with_tooltips(result_text, alternatives)
    except Exception as e:
        # Top-level UI boundary: surface the failure to the user instead of
        # crashing the request. The message is escaped because it is embedded
        # in HTML and may contain markup-like text.
        detail = str(e)
        return (
            f"<div style='padding: 20px; color: red;'>Error: {html.escape(detail)}</div>",
            f"❌ Error: {detail}",
        )
    return result_html, f"✅ Generated {len(alternatives)} tokens"
# Create Gradio interface
with gr.Blocks(title="AI Text Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 AI Text Assistant
    Generate text or summarize articles using state-of-the-art AI models.
    **Hover over any word** in the result to see the top 5 alternative tokens the AI considered!
    """)

    # Input controls: task selector, prompt box, token budget and run button.
    with gr.Row():
        with gr.Column(scale=2):
            mode = gr.Radio(
                choices=["Text Generation", "Text Summarization"],
                value="Text Generation",
                label="Mode",
                info="Choose between generating new text or summarizing existing text"
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter your text here... (max 500 words)",
                lines=6,
                max_lines=10
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                    label="Max Tokens",
                    info="Maximum number of tokens to generate"
                )
            process_btn = gr.Button("🚀 Process", variant="primary", size="lg")
            status = gr.Textbox(label="Status", interactive=False)

    # Output area: the hoverable HTML produced by process_text.
    with gr.Row():
        output_html = gr.HTML(label="Result")

    gr.Markdown("""
    ### 💡 Tips:
    - **Text Generation**: Provide a prompt and the AI will continue writing
    - **Text Summarization**: Paste an article or long text to get a concise summary
    - **Hover** over any word in the output to see what other words the AI considered
    - Models used: Qwen/Qwen2.5-0.5B-Instruct (generation) & facebook/bart-large-cnn (summarization)
    """)

    # Connect the button to the processing function; the two return values of
    # process_text fill the result panel and the status box respectively.
    process_btn.click(
        fn=process_text,
        inputs=[input_text, mode, max_tokens],
        outputs=[output_html, status]
    )

if __name__ == "__main__":
    demo.launch()