"""Gradio demo for RWKV-7 text generation using the ChatRWKV 20B tokenizer."""

import os

import gradio as gr
import requests
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from tqdm import tqdm

TOKENIZER_FILE = "20B_tokenizer.json"
TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"

def download_file(url, filename):
    """Stream `url` to `filename` with a progress bar; skip if the file already exists."""
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))

        with open(filename, 'wb') as file, tqdm(
            desc=filename,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)
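
# The model checkpoints themselves are never fetched by this script. Assuming
# the usual BlinkDL Hugging Face hosting (the URL below is an assumption, not
# verified here), the same helper could download them, e.g.:
#
#     download_file(
#         "https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/"
#         "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
#         "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
#     )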

if not os.path.exists(TOKENIZER_FILE):
    download_file(TOKENIZER_URL, TOKENIZER_FILE)

tokenizer = Tokenizer.from_file(TOKENIZER_FILE)

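# For reference, a minimal sketch of loading a checkpoint through the official
# `rwkv` pip package instead of a raw torch.load (the package wraps the state
# dict in the RWKV architecture; strategy string as documented in ChatRWKV):
#
#     from rwkv.model import RWKV
#     rwkv_model = RWKV(model="RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth",
#                       strategy="cpu fp32")
#     logits, state = rwkv_model.forward(tokenizer.encode("Hello").ids, None)
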
class RWKV_Model:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file {self.model_path} not found")

        # NOTE: this assumes the checkpoint deserializes to a callable module.
        # Official RWKV .pth releases are plain state dicts, so in practice the
        # weights must be wrapped in the model architecture first (see the
        # `rwkv` package sketch above).
        self.model = torch.load(self.model_path, map_location=self.device)
        self.model.eval()
        print("Model loaded successfully")

    def generate(self, prompt, max_length=100, temperature=1.0, top_p=0.9):
        if self.model is None:
            self.load_model()

        input_ids = tokenizer.encode(prompt).ids
        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output_sequence = []

            for _ in range(max_length):
                outputs = self.model(input_tensor)
                next_token_logits = outputs[0, -1, :] / temperature

                # Nucleus (top-p) filtering: keep the smallest set of tokens
                # whose cumulative probability exceeds top_p, mask out the rest.
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the mask right so the first token past the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                output_sequence.append(next_token.item())
                input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)

                # The 20B tokenizer is GPT-NeoX style: its end-of-text token is
                # "<|endoftext|>". ("</s>" is not in its vocabulary, so
                # token_to_id would return None and the loop would never break.)
                if next_token.item() == tokenizer.token_to_id("<|endoftext|>"):
                    break

        return tokenizer.decode(output_sequence)
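
# Simple module-level cache so each button click doesn't reload weights from
# disk (assumes both checkpoints fit in memory side by side).
_MODEL_CACHE = {}


def get_model(model_path):
    if model_path not in _MODEL_CACHE:
        _MODEL_CACHE[model_path] = RWKV_Model(model_path)
    return _MODEL_CACHE[model_path]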

def generate_text(
    prompt,
    temperature=1.0,
    top_p=0.9,
    max_length=100,
    model_size="small"
):
    try:
        model_path = (
            "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
            if model_size == "small"
            else "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
        )

        model = get_model(model_path)

        generated_text = model.generate(
            prompt=prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p
        )

        return generated_text

    except Exception as e:
        return f"Error: {str(e)}"


with gr.Blocks() as demo:
    gr.Markdown("# RWKV-7 Text Generation Demo")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            model_size = gr.Radio(
                choices=["small", "large"],
                label="Model Size",
                value="small"
            )

        with gr.Column():
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                label="Top-p"
            )
            max_length_slider = gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Maximum Length"
            )

    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Output", lines=10)

    generate_button.click(
        fn=generate_text,
        inputs=[
            prompt_input,
            temperature_slider,
            top_p_slider,
            max_length_slider,
            model_size
        ],
        outputs=output_text
    )
| | gr.Markdown(""" |
| | ## Parameters: |
| | - **Temperature**: Controls randomness (higher = more random) |
| | - **Top-p**: Controls diversity (higher = more diverse) |
| | - **Maximum Length**: Maximum number of tokens to generate |
| | - **Model Size**: |
| | - Small (0.1B parameters) |
| | - Large (0.4B parameters) |
| | """) |
| |
|

if __name__ == "__main__":
    demo.launch()
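    # Gradio also supports demo.launch(share=True) for a temporary public URL,
    # or server_name="0.0.0.0" to listen on all interfaces.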