Ronakparmar committed on
Commit
1f6ce96
·
verified ·
1 Parent(s): e30eab8

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +143 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import json
4
+ from transformers import GPT2Tokenizer
5
+ from safetensors.torch import load_file
6
+ import torch.nn as nn
7
+ from torch.nn import functional as F
8
+ from dataclasses import dataclass
9
+
10
+ # [Keep all your model code (CausalSelfAttention, MLP, Block, GPTConfig, GPT classes) as is]
11
+
12
+ # Define the GPTConfig and GPT classes (same as your original code)
13
+ # ...
14
+
15
# Shared module state. Both handles stay None until load_model() fills them in;
# generate_text() checks for None to decide whether a load is still needed.
model = tokenizer = None

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
def load_model():
    """Load the Leap0 model and tokenizer into the module-level globals.

    Reads the model configuration from ``config.json`` and the weights from
    ``model.safetensors`` (both resolved against the current working
    directory), builds a ``GPT`` from them, moves it to ``device``, switches
    it to eval mode, and loads the ``gpt2`` tokenizer.

    The ``model`` and ``tokenizer`` globals are assigned only after every step
    has succeeded, so a failed load never leaves a partially initialised model
    visible to callers.

    Raises:
        Exception: Any error raised while reading the files, building the
            model, or loading the tokenizer is printed and re-raised unchanged.
    """
    global model, tokenizer

    try:
        # Paths to config and model files
        config_path = "config.json"
        model_path = "model.safetensors"

        print(f"Loading configuration from {config_path}...")
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        print("Configuration loaded. Creating model config...")
        config = GPTConfig.from_dict(config_dict)
        print(f"Model config created: {config}")

        print(f"Loading model weights from {model_path}...")
        tensors = load_file(model_path)

        print("Instantiating model...")
        loaded_model = GPT(config)

        print("Loading weights into model...")
        # strict=False is kept from the original (presumably to tolerate tied
        # or non-persistent tensors absent from the checkpoint -- TODO confirm),
        # but the mismatch report is surfaced instead of being thrown away so
        # that randomly initialised layers do not go unnoticed.
        load_result = loaded_model.load_state_dict(tensors, strict=False)
        if load_result.missing_keys:
            print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")
        loaded_model.to(device)
        loaded_model.eval()

        print("Loading tokenizer...")
        loaded_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        # Publish both handles together, only once everything has loaded.
        model, tokenizer = loaded_model, loaded_tokenizer

        print("Model and tokenizer loaded successfully")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
59
+
60
def generate_text(prompt, max_length=50, temperature=0.7, top_k=40):
    """Generate a continuation of ``prompt`` with the loaded model.

    Args:
        prompt: Text to continue. An empty or missing prompt yields ``""``
            without touching the model.
        max_length: Number of new tokens to generate; passed to
            ``model.generate`` as ``max_new_tokens`` after conversion to int.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate`` after conversion
            to int.

    Returns:
        The decoded text of the first sequence returned by ``model.generate``,
        with special tokens skipped.
    """
    # An empty prompt encodes to a zero-length sequence, which gives the model
    # nothing to condition on; answer immediately instead of failing inside it.
    if not prompt:
        return ""

    # Lazy fallback in case the module-level load has not populated the globals.
    if model is None or tokenizer is None:
        load_model()

    # Tokenize the input text
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # UI sliders may hand over floats (e.g. 50.0); the token count and top-k
    # are coerced to int because they are integer quantities.
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=int(max_length),
            temperature=float(temperature),
            top_k=int(top_k),
        )

    # Decode the first (only) sequence in the batch
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
81
+
82
+ # Create the Gradio interface
83
def create_interface():
    """Build and return the Gradio Blocks UI for the Leap0 demo.

    The page has a title and a one-line description, followed by a row with
    two columns: the prompt box, three sampling sliders and the generate
    button in one, and the output box in the other. Clicking the button calls
    ``generate_text`` with the prompt and slider values and writes the result
    to the output box.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(css="footer {visibility: hidden}") as app:
        gr.Markdown("# Leap0 Language Model")
        gr.Markdown("A GPT-2 based model trained on the Tiny Stories dataset")

        with gr.Row():
            # Input column: prompt, sampling controls, trigger button.
            with gr.Column():
                prompt_box = gr.Textbox(
                    lines=3,
                    label="Enter your prompt",
                    placeholder="once upon a time in the village of",
                )

                with gr.Row():
                    length_slider = gr.Slider(
                        label="Max Length", minimum=1, maximum=200, step=1, value=50
                    )
                    temperature_slider = gr.Slider(
                        label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
                    )
                    top_k_slider = gr.Slider(
                        label="Top K", minimum=1, maximum=100, step=1, value=40
                    )

                run_button = gr.Button("Generate Text")

            # Output column: generated text.
            with gr.Column():
                result_box = gr.Textbox(
                    lines=10,
                    label="Generated Output",
                    placeholder="Your generated text will appear here...",
                )

        # Input order matches generate_text's positional parameters.
        controls = [prompt_box, length_slider, temperature_slider, top_k_slider]
        run_button.click(fn=generate_text, inputs=controls, outputs=result_box)

    return app
135
+
136
# Load the model and tokenizer once, at import time, before the UI is built.
load_model()

# Build the UI at module level so that `demo` exists as a module attribute
# whether this file is executed directly or imported.
demo = create_interface()

# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers
4
+ safetensors