Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer | |
| from model import SmolLM2ForCausalLM, SmolLM2Config | |
| import os | |
| import json | |
| import numpy as np | |
| from typing import Dict, Any | |
| class SmolLM2Generator: | |
| def __init__(self, model_path: str = "smollm2_model_final"): | |
| """Initialize the model and tokenizer""" | |
| self.device = self._get_device() | |
| print(f"Using device: {self.device}") | |
| self.model, self.tokenizer = self._load_model(model_path) | |
| print("Model loaded successfully!") | |
| def _get_device(self) -> torch.device: | |
| """Get the best available device""" | |
| if torch.backends.mps.is_available(): | |
| return torch.device("mps") | |
| elif torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| return torch.device("cpu") | |
| def _load_model(self, model_path: str): | |
| """Load model and tokenizer from saved files""" | |
| # Load tokenizer | |
| tokenizer_path = os.path.join(model_path, "tokenizer") | |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) | |
| tokenizer.pad_token = tokenizer.eos_token | |
| # Load config | |
| config_path = os.path.join(model_path, "config.json") | |
| with open(config_path, 'r') as f: | |
| config_dict = json.load(f) | |
| config = SmolLM2Config(**config_dict) | |
| # Initialize and load model | |
| model = SmolLM2ForCausalLM(config) | |
| model_path = os.path.join(model_path, "pytorch_model.bin") | |
| # Use weights_only=True for safer model loading | |
| state_dict = torch.load( | |
| model_path, | |
| map_location=self.device, | |
| weights_only=True # Safer loading option | |
| ) | |
| model.load_state_dict(state_dict) | |
| model.to(self.device) | |
| model.eval() | |
| return model, tokenizer | |
| def generate( | |
| self, | |
| prompt: str, | |
| max_length: int = 100, | |
| temperature: float = 0.8, | |
| top_k: int = 50, | |
| top_p: float = 0.9, | |
| repetition_penalty: float = 1.5, | |
| **kwargs | |
| ) -> str: | |
| """Generate text with advanced parameters""" | |
| inputs = self.tokenizer(prompt, return_tensors="pt") | |
| input_ids = inputs["input_ids"].to(self.device) | |
| attention_mask = inputs["attention_mask"].to(self.device) | |
| # Track generated tokens for repetition penalty | |
| generated_tokens = [] | |
| current_length = input_ids.size(1) | |
| with torch.no_grad(): | |
| while current_length < max_length: | |
| outputs, _ = self.model(input_ids, attention_mask=attention_mask) | |
| next_token_logits = outputs[:, -1, :] / temperature | |
| # Apply repetition penalty | |
| if repetition_penalty != 1.0: | |
| for token in generated_tokens: | |
| next_token_logits[0, token] /= repetition_penalty | |
| # Filter special tokens | |
| for special_token_id in [self.tokenizer.pad_token_id, | |
| self.tokenizer.eos_token_id, | |
| self.tokenizer.bos_token_id]: | |
| if special_token_id is not None: | |
| next_token_logits[0, special_token_id] = float('-inf') | |
| # Apply top-k filtering | |
| if top_k > 0: | |
| indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] | |
| next_token_logits[indices_to_remove] = float('-inf') | |
| # Apply top-p (nucleus) filtering | |
| if top_p < 1.0: | |
| sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) | |
| cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) | |
| sorted_indices_to_remove = cumulative_probs > top_p | |
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |
| sorted_indices_to_remove[..., 0] = 0 | |
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |
| next_token_logits[0, indices_to_remove] = float('-inf') | |
| # Sample from the filtered distribution | |
| probs = torch.softmax(next_token_logits, dim=-1) | |
| next_token = torch.multinomial(probs[0], num_samples=1) | |
| generated_tokens.append(next_token.item()) | |
| input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1) | |
| attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=self.device)], dim=1) | |
| current_length += 1 | |
| generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) | |
| return generated_text | |
| # Initialize the generator | |
| generator = SmolLM2Generator() | |
| def generate_text( | |
| prompt: str, | |
| max_length: int = 100, | |
| temperature: float = 0.8, | |
| top_k: int = 50, | |
| top_p: float = 0.9, | |
| repetition_penalty: float = 1.5 | |
| ) -> tuple[str, str]: | |
| """Gradio interface function""" | |
| try: | |
| generated_text = generator.generate( | |
| prompt=prompt, | |
| max_length=max_length, | |
| temperature=temperature, | |
| top_k=top_k, | |
| top_p=top_p, | |
| repetition_penalty=repetition_penalty | |
| ) | |
| return f"{prompt}{generated_text}", "Status: ✓ Generation successful" | |
| except Exception as e: | |
| return f"Error: {str(e)}", "Status: ⚠ Generation failed" | |
| # Create Gradio interface | |
| with gr.Blocks(title="SmolLM2 Text Generator", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown(""" | |
| # 🤖 SmolLM2 Text Generator | |
| A lightweight language model trained on diverse text data. Experiment with different parameters to control the text generation! [GitHub repo](https://github.com/dhairyag/SmolLM2_GroundUp.git) | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| prompt = gr.Textbox( | |
| label="Prompt", | |
| placeholder="Enter your prompt here...", | |
| lines=3 | |
| ) | |
| generate_btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(scale=1): | |
| with gr.Group(): | |
| gr.Markdown("### Generation Parameters") | |
| max_length = gr.Slider( | |
| minimum=10, maximum=500, value=100, | |
| label="Maximum Length", | |
| info="Maximum number of tokens to generate" | |
| ) | |
| temperature = gr.Slider( | |
| minimum=0.1, maximum=2.0, value=0.8, | |
| label="Temperature", | |
| info="Higher values make output more random, lower values more deterministic" | |
| ) | |
| top_k = gr.Slider( | |
| minimum=0, maximum=100, value=50, | |
| label="Top-k", | |
| info="Limit vocabulary to k most likely tokens" | |
| ) | |
| top_p = gr.Slider( | |
| minimum=0.1, maximum=1.0, value=0.9, | |
| label="Top-p (Nucleus Sampling)", | |
| info="Limit cumulative probability of tokens to sample from" | |
| ) | |
| repetition_penalty = gr.Slider( | |
| minimum=1.0, maximum=2.0, value=1.5, | |
| label="Repetition Penalty", | |
| info="Penalize repeated tokens" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| output = gr.Textbox( | |
| label="Generated Text", | |
| lines=5, | |
| show_copy_button=True | |
| ) | |
| status = gr.Markdown("Status: *Waiting for input...*") | |
| # Handle generation | |
| generate_btn.click( | |
| fn=generate_text, | |
| inputs=[prompt, max_length, temperature, top_k, top_p, repetition_penalty], | |
| outputs=[output, status] | |
| ) | |
| gr.Markdown(""" | |
| ### Tips | |
| - Try different temperatures to control randomness | |
| - Adjust top-k and top-p for better quality | |
| - Use repetition penalty to avoid loops | |
| ### Example Prompts | |
| - "Once upon a time in a distant galaxy..." | |
| - "The scientific method consists of..." | |
| - "The main difference between classical and quantum physics is..." | |
| """) | |
| # Launch the app | |
| if __name__ == "__main__": | |
| demo.launch() |