import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import json
from typing import Dict, List, Tuple
import numpy as np

# Global device selection: prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model names
TEXT_GEN_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
SUMMARIZATION_MODEL = "facebook/bart-large-cnn"

# Load models and tokenizers once at import time so every request reuses them.
print("Loading models...")
gen_tokenizer = AutoTokenizer.from_pretrained(TEXT_GEN_MODEL)
gen_model = AutoModelForCausalLM.from_pretrained(TEXT_GEN_MODEL).to(device)
sum_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZATION_MODEL).to(device)
print("Models loaded successfully!")


def count_words(text: str) -> int:
    """Return the number of whitespace-separated words in *text*."""
    return len(text.split())


def generate_text_with_alternatives(
    input_text: str,
    max_tokens: int = 100,
    top_k: int = 5,
) -> Tuple[str, List[Dict]]:
    """
    Generate a chat response and capture the top-k alternative tokens
    considered at each decoding step.

    Args:
        input_text: The user prompt.
        max_tokens: Maximum number of new tokens to generate.
        top_k: How many candidate tokens to record per step (default 5,
            matching the original hard-coded behavior).

    Returns:
        (generated_text, token_alternatives) where token_alternatives holds
        one list per generated token, each entry a dict of
        {"token": str, "probability": "NN.NN%"}.
    """
    # Wrap the raw prompt in the model's chat template.
    messages = [{"role": "user", "content": input_text}]
    text = gen_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = gen_tokenizer(text, return_tensors="pt").to(device)

    # Ask generate() for per-step logits (output_scores) alongside the ids.
    with torch.no_grad():
        outputs = gen_model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False,  # Greedy decoding
            pad_token_id=gen_tokenizer.eos_token_id,
        )

    # Drop the prompt tokens; keep only the newly generated suffix.
    generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
    generated_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Convert each step's logits into the top-k candidate tokens.
    token_alternatives = []
    if getattr(outputs, "scores", None):
        for score_tensor in outputs.scores:
            probs = torch.nn.functional.softmax(score_tensor[0], dim=-1)
            # Guard against vocabularies smaller than top_k.
            k = min(top_k, probs.shape[-1])
            top_probs, top_indices = torch.topk(probs, k=k)
            alternatives = [
                {
                    "token": gen_tokenizer.decode([idx.item()]),
                    "probability": f"{prob.item() * 100:.2f}%",
                }
                for prob, idx in zip(top_probs, top_indices)
            ]
            token_alternatives.append(alternatives)

    return generated_text, token_alternatives
Returns: (generated_text, token_alternatives) """ # Prepare input messages = [{"role": "user", "content": input_text}] text = gen_tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = gen_tokenizer(text, return_tensors="pt").to(device) # Generate with output_scores to get token probabilities with torch.no_grad(): outputs = gen_model.generate( **inputs, max_new_tokens=max_tokens, output_scores=True, return_dict_in_generate=True, do_sample=False, # Greedy decoding pad_token_id=gen_tokenizer.eos_token_id ) # Get generated tokens (excluding input) generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:] generated_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True) # Extract token alternatives from scores token_alternatives = [] if hasattr(outputs, 'scores') and outputs.scores: for score_tensor in outputs.scores: # Get probabilities probs = torch.nn.functional.softmax(score_tensor[0], dim=-1) # Get top 5 tokens top_probs, top_indices = torch.topk(probs, k=5) alternatives = [] for prob, idx in zip(top_probs, top_indices): token = gen_tokenizer.decode([idx.item()]) alternatives.append({ "token": token, "probability": f"{prob.item() * 100:.2f}%" }) token_alternatives.append(alternatives) return generated_text, token_alternatives def summarize_text_with_alternatives( input_text: str, max_tokens: int = 100 ) -> Tuple[str, List[Dict]]: """ Summarize text and capture top-5 alternative tokens for each generated token. 
Returns: (summary_text, token_alternatives) """ inputs = sum_tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True).to(device) # Generate with output_scores with torch.no_grad(): outputs = sum_model.generate( **inputs, max_length=max_tokens, output_scores=True, return_dict_in_generate=True, do_sample=False, # Greedy decoding ) # Decode summary summary_text = sum_tokenizer.decode(outputs.sequences[0], skip_special_tokens=True) # Extract token alternatives token_alternatives = [] if hasattr(outputs, 'scores') and outputs.scores: for score_tensor in outputs.scores: probs = torch.nn.functional.softmax(score_tensor[0], dim=-1) top_probs, top_indices = torch.topk(probs, k=5) alternatives = [] for prob, idx in zip(top_probs, top_indices): token = sum_tokenizer.decode([idx.item()]) alternatives.append({ "token": token, "probability": f"{prob.item() * 100:.2f}%" }) token_alternatives.append(alternatives) return summary_text, token_alternatives def create_html_with_tooltips(text: str, token_alternatives: List[Dict]) -> str: """ Create HTML with hoverable words that show token alternatives. """ if not token_alternatives: return f"