| | """ |
| | DeepSeek Children's Stories Text Generation |
| | Generate children's stories using the trained DeepSeek model |
| | """ |
| |
|
| | import os |
| | import sys |
| | import argparse |
| | import torch |
| | import tiktoken |
| | from typing import List, Optional |
| |
|
| | |
# Make the parent directory importable so the `model` package resolves when
# this script is executed directly from its own folder.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from model.deepseek import DeepSeek, DeepSeekConfig

# Allowlist the config class for unpickling so torch.load can restore
# checkpoints under PyTorch's weights_only safe-loading scheme.
torch.serialization.add_safe_globals([DeepSeekConfig])
| |
|
class DeepSeekStoryGenerator:
    """Generate children's stories from a trained DeepSeek checkpoint.

    Wraps model loading, prompt encoding with the training-time delimiter
    markers, generation, and extraction of the story text from the raw
    model output.
    """

    def __init__(self, model_path: str, device: str = 'auto'):
        """Initialize the story generator.

        Args:
            model_path: Path to a torch checkpoint containing 'config' and
                'model' entries.
            device: 'auto' (use CUDA when available), 'cuda', or 'cpu'.
        """
        self.device = self._get_device(device)
        self.model = self._load_model(model_path)
        self.tokenizer = tiktoken.get_encoding("gpt2")

        # Delimiters used by the training data to mark prompt/story/etc.
        # NOTE: these are encoded with encode_ordinary() below, so they are
        # plain text sequences, not dedicated tokenizer special tokens.
        self.special_tokens = {
            "story_start": "<|story|>",
            "story_end": "</|story|>",
            "prompt_start": "<|prompt|>",
            "prompt_end": "</|prompt|>",
            "moral_start": "<|moral|>",
            "moral_end": "</|moral|>",
            "character_start": "<|character|>",
            "character_end": "</|character|>"
        }

    def _get_device(self, device: str) -> str:
        """Resolve 'auto' to 'cuda' when available, else 'cpu'.

        Any explicit device string is returned unchanged.
        """
        if device == 'auto':
            return 'cuda' if torch.cuda.is_available() else 'cpu'
        return device

    def _load_model(self, model_path: str) -> "DeepSeek":
        """Load the trained model from a checkpoint file.

        The checkpoint must contain a 'config' (DeepSeekConfig) and a
        'model' state dict. State dicts saved from a torch.compile()d
        model carry an '_orig_mod.' prefix on every key, which is stripped
        here so the weights load into a plain (uncompiled) model.

        NOTE(review): weights_only=False unpickles arbitrary objects from
        the checkpoint — only load checkpoints from trusted sources.
        """
        print(f"Loading model from {model_path}...")

        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        config = checkpoint['config']
        model = DeepSeek(config)

        # Strip the torch.compile wrapper prefix when (and only when) every
        # key carries it. The extra truthiness check avoids the vacuous
        # all()==True on an empty state dict.
        state_dict = checkpoint['model']
        compiled_prefix = '_orig_mod.'
        if state_dict and all(k.startswith(compiled_prefix) for k in state_dict):
            state_dict = {k[len(compiled_prefix):]: v for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()

        print("Model loaded successfully!")
        print(f"Model configuration: {config.n_layer}L/{config.n_head}H/{config.n_embd}D")
        print(f"Device: {self.device}")

        return model

    def encode_prompt(self, prompt: str, character: Optional[str] = None) -> torch.Tensor:
        """Encode a prompt (plus optional character) for generation.

        Builds the delimited text the model was trained on — prompt,
        optional character, then an opening story marker — lowercased to
        match the training data, and returns a (1, seq_len) LongTensor on
        the generator's device.
        """
        full_prompt = f"{self.special_tokens['prompt_start']} {prompt.lower()} {self.special_tokens['prompt_end']}"

        if character:
            full_prompt += f" {self.special_tokens['character_start']} {character.lower()} {self.special_tokens['character_end']}"

        full_prompt += f" {self.special_tokens['story_start']}"

        # encode_ordinary ignores tiktoken special tokens; the markers above
        # are encoded as plain text.
        token_ids = self.tokenizer.encode_ordinary(full_prompt)
        return torch.tensor([token_ids], dtype=torch.long, device=self.device)

    def generate_story(self, prompt: str, character: Optional[str] = None,
                       max_tokens: int = 200, temperature: float = 0.8,
                       top_k: int = 40, top_p: float = 0.9) -> str:
        """Generate a children's story for the given prompt.

        Args:
            prompt: Free-text story prompt.
            character: Optional character name woven into the prompt.
            max_tokens: Maximum number of new tokens to sample.
            temperature: Sampling temperature.
            top_k: Top-k sampling cutoff.
            top_p: Accepted for interface compatibility but NOT forwarded
                to model.generate(). NOTE(review): confirm whether the
                model supports nucleus sampling and forward it if so.

        Returns:
            The extracted story text.
        """
        print(f"Generating story for prompt: '{prompt}'")
        if character:
            print(f"Character: {character}")

        input_ids = self.encode_prompt(prompt, character)

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k
            )

        generated_text = self.tokenizer.decode(generated_ids[0].tolist())

        story = self._extract_story(generated_text)

        return story

    def _extract_story(self, text: str) -> str:
        """Extract the story body from raw generated text.

        Prefers the span between the story start/end markers; falls back to
        everything after the prompt-end marker, then to the whole text.
        """
        story_start = text.find(self.special_tokens['story_start'])
        story_end = text.find(self.special_tokens['story_end'])

        if story_start != -1 and story_end != -1:
            # Both markers present: return the text between them.
            story_content = text[story_start + len(self.special_tokens['story_start']):story_end].strip()
            return story_content
        else:
            # No complete story span; return whatever follows the prompt.
            prompt_end = text.find(self.special_tokens['prompt_end'])
            if prompt_end != -1:
                return text[prompt_end + len(self.special_tokens['prompt_end']):].strip()
            else:
                return text.strip()

    def generate_multiple_stories(self, prompts: List[str], num_stories: int = 3,
                                  **kwargs) -> List[str]:
        """Generate one story per prompt in `prompts`.

        Args:
            prompts: Prompts to generate from; the number of stories equals
                len(prompts). num_stories is accepted for interface
                compatibility but unused — callers control the count via
                the prompts list.
            **kwargs: Forwarded to generate_story() (character, max_tokens,
                temperature, top_k, top_p).

        Returns:
            The generated stories, in prompt order.
        """
        stories = []

        for i, prompt in enumerate(prompts):
            print(f"\nGenerating story {i+1}/{len(prompts)}...")
            story = self.generate_story(prompt, **kwargs)
            stories.append(story)

        return stories

    def interactive_generation(self):
        """Run an interactive REPL-style story generation loop.

        Prompts the user for a story idea, optional character, and sampling
        parameters; prints each generated story. Exits on 'quit'/'exit'/'q'
        or Ctrl-C.
        """
        print("DeepSeek Children's Stories - Interactive Mode")
        print("Type 'quit' to exit")
        print("-" * 50)

        while True:
            try:
                prompt = input("\nEnter a story prompt: ").strip()

                if prompt.lower() in ['quit', 'exit', 'q']:
                    print("Goodbye!")
                    break

                if not prompt:
                    print("Please enter a valid prompt.")
                    continue

                character = input("Enter a character name (optional): ").strip()
                if not character:
                    character = None

                # Fall back to defaults on any malformed numeric input.
                try:
                    max_tokens = int(input("Max tokens (default 200): ") or "200")
                    temperature = float(input("Temperature (default 0.8): ") or "0.8")
                except ValueError:
                    max_tokens = 200
                    temperature = 0.8

                story = self.generate_story(
                    prompt,
                    character=character,
                    max_tokens=max_tokens,
                    temperature=temperature
                )

                print("\n" + "="*50)
                print("GENERATED STORY:")
                print("="*50)
                print(story)
                print("="*50)

            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                # Keep the interactive session alive on generation errors.
                print(f"Error generating story: {e}")
| |
|
| |
|
def main():
    """Command-line entry point: parse arguments and generate stories.

    Validates all arguments (model file exists, a prompt or --interactive
    was supplied) before constructing the generator, then dispatches to
    interactive mode, single-story, or multi-story generation.
    """
    parser = argparse.ArgumentParser(description='Generate children\'s stories with DeepSeek')

    # Model / runtime options.
    parser.add_argument('--model-path', type=str, default='checkpoints/best_model.pt',
                        help='Path to the trained model checkpoint')
    parser.add_argument('--device', type=str, default='auto',
                        help='Device to use (auto, cuda, cpu)')

    # Sampling options.
    parser.add_argument('--prompt', type=str, help='Story prompt')
    parser.add_argument('--character', type=str, help='Character name')
    parser.add_argument('--max-tokens', type=int, default=200, help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=40, help='Top-k sampling')
    parser.add_argument('--top-p', type=float, default=0.9, help='Top-p sampling')

    # Batch / mode options.
    parser.add_argument('--num-stories', type=int, default=1, help='Number of stories to generate')
    parser.add_argument('--interactive', action='store_true', help='Interactive mode')

    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        print(f"Error: Model file not found at {args.model_path}")
        print("Please train the model first or specify the correct path.")
        return

    # Fail fast: validate the mode BEFORE loading the model, so a missing
    # prompt does not waste time deserializing the checkpoint. (Fix: the
    # original built the generator and only then printed the usage hint.)
    if not args.interactive and not args.prompt:
        print("Please provide a prompt or use --interactive mode.")
        print("Example: python generate.py --prompt 'A brave little mouse' --character 'Mickey'")
        return

    generator = DeepSeekStoryGenerator(args.model_path, args.device)

    if args.interactive:
        generator.interactive_generation()
    elif args.num_stories == 1:
        # Single story: generate and print with a banner.
        story = generator.generate_story(
            args.prompt,
            character=args.character,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p
        )

        print(f"\nPrompt: {args.prompt}")
        if args.character:
            print(f"Character: {args.character}")
        print("\n" + "="*50)
        print("GENERATED STORY:")
        print("="*50)
        print(story)
        print("="*50)
    else:
        # Multiple stories: reuse the same prompt num_stories times.
        prompts = [args.prompt] * args.num_stories
        stories = generator.generate_multiple_stories(
            prompts,
            num_stories=args.num_stories,
            character=args.character,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p
        )

        for i, story in enumerate(stories):
            print(f"\nStory {i+1}:")
            print("="*50)
            print(story)
            print("="*50)
| |
|
| |
|
# Run the CLI only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
| |
|