ShubhamBaghel309 commited on
Commit
4dcdb06
·
1 Parent(s): 4fdae9f

Fix merge conflicts in app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -45
app.py CHANGED
@@ -28,13 +28,12 @@ def load_model():
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  model = GPTModel(GPT_CONFIG_124M)
30
 
31
- <<<<<<< HEAD
32
  try:
33
  # Download model from HuggingFace Model Hub
34
  print("📥 Downloading model from HuggingFace...")
35
  model_path = hf_hub_download(
36
  repo_id="ShubhamBaghel307/miniGPT-124M",
37
- filename="model.pth", # or "model_and_optimizer.pth"
38
  repo_type="model"
39
  )
40
 
@@ -52,15 +51,6 @@ def load_model():
52
  except Exception as e:
53
  print(f"⚠️ Error loading model: {e}")
54
  print("Using randomly initialized weights")
55
- =======
56
- # Try to load trained weights if available
57
- model_path = Path("model.pth")
58
- if model_path.exists():
59
- model.load_state_dict(torch.load(model_path, map_location=device))
60
- print("✅ Loaded trained model weights")
61
- else:
62
- print("⚠️ No trained weights found, using random initialization")
63
- >>>>>>> 9d99f6e730bfb0a7922a4a03324fc61f27387778
64
 
65
  model.to(device)
66
  model.eval()
@@ -114,14 +104,14 @@ def generate_text(prompt, max_new_tokens=50, temperature=1.0, top_k=50):
114
  return f"Error: {str(e)}"
115
 
116
  # Create Gradio interface
117
- with gr.Blocks(theme=gr.themes.Soft(), title="MiniGPT Chat") as demo:
118
  gr.Markdown(
119
  """
120
- # 🤖 MiniGPT Chat Interface
121
 
122
- A simple chat interface for your trained GPT model. Enter a prompt and adjust parameters to generate text.
123
 
124
- **Note:** This model is trained from scratch on limited data, so outputs may not be as coherent as ChatGPT!
125
  """
126
  )
127
 
@@ -129,12 +119,12 @@ with gr.Blocks(theme=gr.themes.Soft(), title="MiniGPT Chat") as demo:
129
  with gr.Column(scale=2):
130
  prompt_input = gr.Textbox(
131
  label="Enter your prompt",
132
- placeholder="When forty winters shall...",
133
  lines=3
134
  )
135
 
136
  with gr.Row():
137
- generate_btn = gr.Button("Generate Text", variant="primary", size="lg")
138
  clear_btn = gr.Button("Clear", size="lg")
139
 
140
  output_text = gr.Textbox(
@@ -144,15 +134,14 @@ with gr.Blocks(theme=gr.themes.Soft(), title="MiniGPT Chat") as demo:
144
  )
145
 
146
  with gr.Column(scale=1):
147
- gr.Markdown("### ⚙️ Generation Parameters")
148
 
149
  max_tokens = gr.Slider(
150
  minimum=10,
151
  maximum=200,
152
  value=50,
153
  step=10,
154
- label="Max New Tokens",
155
- info="Maximum number of tokens to generate"
156
  )
157
 
158
  temperature = gr.Slider(
@@ -160,8 +149,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="MiniGPT Chat") as demo:
160
  maximum=2.0,
161
  value=0.7,
162
  step=0.1,
163
- label="Temperature",
164
- info="Higher = more random, Lower = more focused"
165
  )
166
 
167
  top_k = gr.Slider(
@@ -169,24 +157,18 @@ with gr.Blocks(theme=gr.themes.Soft(), title="MiniGPT Chat") as demo:
169
  maximum=100,
170
  value=50,
171
  step=5,
172
- label="Top-K",
173
- info="Sample from top K tokens (0 = disabled)"
174
  )
175
 
176
- # Example prompts
177
- gr.Markdown("### 💡 Example Prompts")
178
  gr.Examples(
179
  examples=[
180
- ["Every effort moves you"],
181
- ["When forty winters shall"],
182
- ["The quick brown fox"],
183
  ["Once upon a time"],
 
184
  ["In a world where"],
185
  ],
186
  inputs=prompt_input,
187
  )
188
 
189
- # Event handlers
190
  generate_btn.click(
191
  fn=generate_text,
192
  inputs=[prompt_input, max_tokens, temperature, top_k],
@@ -195,22 +177,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="MiniGPT Chat") as demo:
195
 
196
  clear_btn.click(
197
  fn=lambda: ("", ""),
198
- inputs=None,
199
  outputs=[prompt_input, output_text]
200
  )
201
-
202
- # Footer
203
- gr.Markdown(
204
- """
205
- ---
206
- Built with ❤️ using [Gradio](https://gradio.app) | Model trained from scratch following "Build a Large Language Model (From Scratch)"
207
- """
208
- )
209
 
210
- # Launch the app
211
  if __name__ == "__main__":
212
- demo.launch(
213
- share=False,
214
- server_name="0.0.0.0",
215
- server_port=7860
216
- )
 
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  model = GPTModel(GPT_CONFIG_124M)
30
 
 
31
  try:
32
  # Download model from HuggingFace Model Hub
33
  print("📥 Downloading model from HuggingFace...")
34
  model_path = hf_hub_download(
35
  repo_id="ShubhamBaghel307/miniGPT-124M",
36
+ filename="model_and_optimizer.pth",
37
  repo_type="model"
38
  )
39
 
 
51
  except Exception as e:
52
  print(f"⚠️ Error loading model: {e}")
53
  print("Using randomly initialized weights")
 
 
 
 
 
 
 
 
 
54
 
55
  model.to(device)
56
  model.eval()
 
104
  return f"Error: {str(e)}"
105
 
106
  # Create Gradio interface
107
+ with gr.Blocks(theme=gr.themes.Soft(), title="MiniGPT") as demo:
108
  gr.Markdown(
109
  """
110
+ # 🤖 MiniGPT - Text Generator
111
 
112
+ A GPT-2 style language model trained from scratch. Enter a prompt and watch it generate text!
113
 
114
+ **Model:** 124M parameters | **Context:** 256 tokens
115
  """
116
  )
117
 
 
119
  with gr.Column(scale=2):
120
  prompt_input = gr.Textbox(
121
  label="Enter your prompt",
122
+ placeholder="Once upon a time...",
123
  lines=3
124
  )
125
 
126
  with gr.Row():
127
+ generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
128
  clear_btn = gr.Button("Clear", size="lg")
129
 
130
  output_text = gr.Textbox(
 
134
  )
135
 
136
  with gr.Column(scale=1):
137
+ gr.Markdown("### ⚙️ Parameters")
138
 
139
  max_tokens = gr.Slider(
140
  minimum=10,
141
  maximum=200,
142
  value=50,
143
  step=10,
144
+ label="Max Tokens"
 
145
  )
146
 
147
  temperature = gr.Slider(
 
149
  maximum=2.0,
150
  value=0.7,
151
  step=0.1,
152
+ label="Temperature"
 
153
  )
154
 
155
  top_k = gr.Slider(
 
157
  maximum=100,
158
  value=50,
159
  step=5,
160
+ label="Top-K"
 
161
  )
162
 
 
 
163
  gr.Examples(
164
  examples=[
 
 
 
165
  ["Once upon a time"],
166
+ ["The future of AI is"],
167
  ["In a world where"],
168
  ],
169
  inputs=prompt_input,
170
  )
171
 
 
172
  generate_btn.click(
173
  fn=generate_text,
174
  inputs=[prompt_input, max_tokens, temperature, top_k],
 
177
 
178
  clear_btn.click(
179
  fn=lambda: ("", ""),
 
180
  outputs=[prompt_input, output_text]
181
  )
 
 
 
 
 
 
 
 
182
 
 
183
  if __name__ == "__main__":
184
+ demo.launch(server_name="0.0.0.0", server_port=7860)