File size: 10,746 Bytes
01ae771
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
"""
DeepSeek Children's Stories Text Generation
Generate children's stories using the trained DeepSeek model
"""

import os
import sys
import argparse
import torch
import tiktoken
from typing import List, Optional

# Add the src directory to Python path
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from model.deepseek import DeepSeek, DeepSeekConfig

# Allowlist DeepSeekConfig for safe deserialization
torch.serialization.add_safe_globals([DeepSeekConfig])

class DeepSeekStoryGenerator:
    """Generate children's stories from a trained DeepSeek checkpoint.

    Handles model loading, structured prompt encoding with the project's
    special story tokens, sampling, and extraction of the story text from
    the raw decoded output.
    """

    def __init__(self, model_path: str, device: str = 'auto'):
        """Initialize the story generator.

        Args:
            model_path: Path to a torch checkpoint containing 'config' and
                'model' entries (see _load_model).
            device: 'auto' (pick CUDA when available), 'cuda', or 'cpu'.
        """
        self.device = self._get_device(device)
        self.model = self._load_model(model_path)
        # GPT-2 BPE; must match the encoding used when the model was trained.
        self.tokenizer = tiktoken.get_encoding("gpt2")

        # Structural markers embedded in the training data; generated output
        # is parsed against these in _extract_story().
        self.special_tokens = {
            "story_start": "<|story|>",
            "story_end": "</|story|>",
            "prompt_start": "<|prompt|>",
            "prompt_end": "</|prompt|>",
            "moral_start": "<|moral|>",
            "moral_end": "</|moral|>",
            "character_start": "<|character|>",
            "character_end": "</|character|>"
        }

    def _get_device(self, device: str) -> str:
        """Resolve 'auto' to 'cuda' when available (else 'cpu'); pass through otherwise."""
        if device == 'auto':
            return 'cuda' if torch.cuda.is_available() else 'cpu'
        return device

    def _load_model(self, model_path: str) -> DeepSeek:
        """Load the trained model onto self.device and put it in eval mode.

        The checkpoint must contain a 'config' entry (model configuration)
        and a 'model' state dict, optionally saved from a torch.compile()'d
        module.

        Raises:
            KeyError: if the checkpoint lacks 'config' or 'model' entries.
        """
        print(f"Loading model from {model_path}...")

        # NOTE(review): weights_only=False unpickles arbitrary objects and is
        # unsafe for untrusted checkpoints. DeepSeekConfig is already
        # allow-listed at module import time, so weights_only=True may work
        # here -- confirm against a real checkpoint before switching.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Rebuild the architecture from the stored configuration.
        config = checkpoint['config']
        model = DeepSeek(config)

        # torch.compile() stores parameters under an '_orig_mod.' prefix;
        # strip it so the keys match the uncompiled module.
        state_dict = checkpoint['model']
        if all(k.startswith('_orig_mod.') for k in state_dict.keys()):
            state_dict = {k[len('_orig_mod.'):]: v for k, v in state_dict.items()}

        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()

        print("Model loaded successfully!")
        print(f"Model configuration: {config.n_layer}L/{config.n_head}H/{config.n_embd}D")
        print(f"Device: {self.device}")

        return model

    def encode_prompt(self, prompt: str, character: Optional[str] = None) -> torch.Tensor:
        """Encode a prompt (and optional character) into a (1, T) token tensor.

        The text is lower-cased and wrapped in the structural markers the
        model was trained on, ending with the story-start marker so the model
        continues with story text.
        """
        full_prompt = f"{self.special_tokens['prompt_start']} {prompt.lower()} {self.special_tokens['prompt_end']}"

        if character:
            full_prompt += f" {self.special_tokens['character_start']} {character.lower()} {self.special_tokens['character_end']}"

        full_prompt += f" {self.special_tokens['story_start']}"

        # encode_ordinary treats the markers as plain text (ordinary BPE
        # pieces) rather than tiktoken special tokens.
        token_ids = self.tokenizer.encode_ordinary(full_prompt)
        return torch.tensor([token_ids], dtype=torch.long, device=self.device)

    def generate_story(self, prompt: str, character: Optional[str] = None,
                      max_tokens: int = 200, temperature: float = 0.8,
                      top_k: int = 40, top_p: float = 0.9) -> str:
        """Generate a single children's story.

        Args:
            prompt: Free-text story prompt.
            character: Optional character name woven into the prompt.
            max_tokens: Maximum number of new tokens to sample.
            temperature: Sampling temperature.
            top_k: Top-k sampling cutoff.
            top_p: Accepted for interface compatibility but currently unused;
                the underlying model.generate() call only takes top_k.

        Returns:
            The extracted story text.
        """
        print(f"Generating story for prompt: '{prompt}'")
        if character:
            print(f"Character: {character}")

        # Encode prompt
        input_ids = self.encode_prompt(prompt, character)

        # Inference only -- no gradients needed.
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k
            )

        # Decode the generated token ids back to text.
        generated_text = self.tokenizer.decode(generated_ids[0].tolist())

        # Extract the story part
        return self._extract_story(generated_text)

    def _extract_story(self, text: str) -> str:
        """Extract the story body from raw generated text.

        Returns the span between the story start/end markers. When generation
        was cut off before the end marker was produced, returns everything
        after the start marker (previously this case fell through to the
        prompt fallback and leaked special tokens into the returned story).
        With no start marker at all, falls back to the text after the prompt
        block, or the whole text.
        """
        start_tok = self.special_tokens['story_start']
        end_tok = self.special_tokens['story_end']
        story_start = text.find(start_tok)
        story_end = text.find(end_tok)

        if story_start != -1:
            begin = story_start + len(start_tok)
            # Guard against story_end == -1: generation frequently hits
            # max_new_tokens before emitting the end marker.
            end = story_end if story_end != -1 else len(text)
            return text[begin:end].strip()

        # Fallback: everything after the prompt block, else the raw text.
        prompt_end = text.find(self.special_tokens['prompt_end'])
        if prompt_end != -1:
            return text[prompt_end + len(self.special_tokens['prompt_end']):].strip()
        return text.strip()

    def generate_multiple_stories(self, prompts: List[str], num_stories: int = 3,
                                **kwargs) -> List[str]:
        """Generate one story per prompt.

        Args:
            prompts: Prompts to generate from, one story each.
            num_stories: Accepted for interface compatibility; the number of
                stories generated is determined by len(prompts), not by this
                value.
            **kwargs: Forwarded to generate_story() (character, max_tokens,
                temperature, top_k, top_p).

        Returns:
            The generated stories, in prompt order.
        """
        stories = []

        for i, prompt in enumerate(prompts):
            print(f"\nGenerating story {i+1}/{len(prompts)}...")
            story = self.generate_story(prompt, **kwargs)
            stories.append(story)

        return stories

    def interactive_generation(self):
        """Run a REPL that prompts the user for story parameters until 'quit'."""
        print("DeepSeek Children's Stories - Interactive Mode")
        print("Type 'quit' to exit")
        print("-" * 50)

        while True:
            try:
                # Get prompt from user
                prompt = input("\nEnter a story prompt: ").strip()

                if prompt.lower() in ['quit', 'exit', 'q']:
                    print("Goodbye!")
                    break

                if not prompt:
                    print("Please enter a valid prompt.")
                    continue

                # Optional character name; empty input means none.
                character = input("Enter a character name (optional): ").strip()
                if not character:
                    character = None

                # Fall back to defaults on any malformed numeric input.
                try:
                    max_tokens = int(input("Max tokens (default 200): ") or "200")
                    temperature = float(input("Temperature (default 0.8): ") or "0.8")
                except ValueError:
                    max_tokens = 200
                    temperature = 0.8

                # Generate story
                story = self.generate_story(
                    prompt,
                    character=character,
                    max_tokens=max_tokens,
                    temperature=temperature
                )

                # Display story
                print("\n" + "="*50)
                print("GENERATED STORY:")
                print("="*50)
                print(story)
                print("="*50)

            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                # Keep the REPL alive on generation failures.
                print(f"Error generating story: {e}")


def _parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface for story generation."""
    parser = argparse.ArgumentParser(description='Generate children\'s stories with DeepSeek')

    # Model configuration
    parser.add_argument('--model-path', type=str, default='checkpoints/best_model.pt',
                       help='Path to the trained model checkpoint')
    parser.add_argument('--device', type=str, default='auto',
                       help='Device to use (auto, cuda, cpu)')

    # Generation parameters
    parser.add_argument('--prompt', type=str, help='Story prompt')
    parser.add_argument('--character', type=str, help='Character name')
    parser.add_argument('--max-tokens', type=int, default=200, help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=40, help='Top-k sampling')
    parser.add_argument('--top-p', type=float, default=0.9, help='Top-p sampling')

    # Multiple generation
    parser.add_argument('--num-stories', type=int, default=1, help='Number of stories to generate')
    parser.add_argument('--interactive', action='store_true', help='Interactive mode')

    return parser.parse_args()


def main():
    """Main generation function: parse CLI arguments and dispatch to the right mode."""
    args = _parse_args()

    # Bail out early when the checkpoint is missing.
    if not os.path.exists(args.model_path):
        print(f"Error: Model file not found at {args.model_path}")
        print("Please train the model first or specify the correct path.")
        return

    generator = DeepSeekStoryGenerator(args.model_path, args.device)

    if args.interactive:
        # Interactive REPL mode.
        generator.interactive_generation()
        return

    if not args.prompt:
        print("Please provide a prompt or use --interactive mode.")
        print("Example: python generate.py --prompt 'A brave little mouse' --character 'Mickey'")
        return

    # Sampling options shared by single and multiple generation.
    sampling = dict(
        character=args.character,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )

    if args.num_stories == 1:
        # Single story
        story = generator.generate_story(args.prompt, **sampling)

        print(f"\nPrompt: {args.prompt}")
        if args.character:
            print(f"Character: {args.character}")
        print("\n" + "="*50)
        print("GENERATED STORY:")
        print("="*50)
        print(story)
        print("="*50)
    else:
        # Multiple stories: repeat the same prompt num_stories times.
        repeated_prompts = [args.prompt] * args.num_stories
        stories = generator.generate_multiple_stories(
            repeated_prompts,
            num_stories=args.num_stories,
            **sampling
        )

        for idx, story in enumerate(stories):
            print(f"\nStory {idx+1}:")
            print("="*50)
            print(story)
            print("="*50)


if __name__ == "__main__":
    main()