huang342 committed on
Commit
b53720b
·
verified ·
1 Parent(s): f2c4f12

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +61 -18
  2. model.pth +1 -1
app.py CHANGED
@@ -10,7 +10,6 @@ Original file is located at
10
  import torch
11
  import torch.nn as nn
12
  from torch.nn import functional as F
13
- import gradio as gr
14
  import requests
15
 
16
  # hyperparameters
@@ -21,13 +20,14 @@ n_head = 4
21
  n_layer = 4
22
  dropout = 0.0
23
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
24
 
25
  # File path for saving the Book of Mormon text
26
  file_path = "Book of Mormon.txt"
27
 
28
- url = "https://raw.githubusercontent.com/huang-0505/LLM/refs/heads/main/Book%20of%20Mormon.txt"
29
-
30
  # Download and save the file
 
31
  response = requests.get(url)
32
  with open("Book of Mormon.txt", "w", encoding="utf-8") as f:
33
  f.write(response.text)
@@ -43,6 +43,22 @@ itos = {i: ch for i, ch in enumerate(chars)}
43
  encode = lambda s: [stoi[c] for c in s]
44
  decode = lambda l: ''.join([itos[i] for i in l])
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # Model definition
47
  class BigramLanguageModel(nn.Module):
48
  def __init__(self):
@@ -127,34 +143,61 @@ class FeedForward(nn.Module):
127
  def forward(self, x):
128
  return self.net(x)
129
 
130
- # Load pre-trained model
131
- model = BigramLanguageModel()
132
- model.load_state_dict(torch.load('model.pth', map_location=device))
133
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- # Gradio functions
136
  def ask_question(question, max_new_tokens=100):
 
137
  context_text = f"Q: {question}\nA:"
138
- context_tokens = torch.tensor(encode(context_text), dtype=torch.long, device=device).unsqueeze(0)
 
 
139
  generated_tokens = model.generate(context_tokens, max_new_tokens=max_new_tokens)
 
 
140
  generated_text = decode(generated_tokens[0].tolist())
141
- return generated_text.split("A:")[1].strip()
142
 
 
 
 
 
 
143
  def chatbot_response(question):
144
- try:
145
- return ask_question(question)
146
- except Exception as e:
147
- return f"Error: {e}"
 
148
 
149
- # Gradio Interface
150
  demo = gr.Interface(
151
  fn=chatbot_response,
152
  inputs="text",
153
  outputs="text",
154
  title="Religious Chatbot",
155
- description="Ask questions about the Book of Mormon."
156
  )
157
 
158
  # Launch the app
159
- demo.launch()
160
-
 
10
  import torch
11
  import torch.nn as nn
12
  from torch.nn import functional as F
 
13
  import requests
14
 
15
  # hyperparameters
 
20
  n_layer = 4
21
  dropout = 0.0
22
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
+ learning_rate = 1e-3
24
+ max_iters = 5000 # Number of training iterations
25
 
26
  # File path for saving the Book of Mormon text
27
  file_path = "Book of Mormon.txt"
28
 
 
 
29
  # Download and save the file
30
+ url = "https://raw.githubusercontent.com/huang-0505/LLM/refs/heads/main/Book%20of%20Mormon.txt"
31
  response = requests.get(url)
32
  with open("Book of Mormon.txt", "w", encoding="utf-8") as f:
33
  f.write(response.text)
 
43
  encode = lambda s: [stoi[c] for c in s]
44
  decode = lambda l: ''.join([itos[i] for i in l])
45
 
46
+ # Encode the dataset
47
+ data = torch.tensor(encode(text), dtype=torch.long)
48
+
49
+ # Split into training and validation sets
50
+ n = int(0.9 * len(data)) # 90% training, 10% validation
51
+ train_data = data[:n]
52
+ val_data = data[n:]
53
+
54
+ # Function to get batches of data
55
+ def get_batch(split):
56
+ data = train_data if split == "train" else val_data
57
+ ix = torch.randint(len(data) - block_size, (batch_size,))
58
+ x = torch.stack([data[i:i + block_size] for i in ix])
59
+ y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
60
+ return x.to(device), y.to(device)
61
+
62
  # Model definition
63
  class BigramLanguageModel(nn.Module):
64
  def __init__(self):
 
143
  def forward(self, x):
144
  return self.net(x)
145
 
146
+ # Initialize model and optimizer
147
+ model = BigramLanguageModel().to(device)
148
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
149
+
150
+ # Training loop
151
+ for iter in range(max_iters):
152
+ xb, yb = get_batch("train")
153
+ logits, loss = model(xb, yb)
154
+ optimizer.zero_grad()
155
+ loss.backward()
156
+ optimizer.step()
157
+
158
+ if iter % 100 == 0:
159
+ print(f"Step {iter}: Loss = {loss.item()}")
160
+
161
+ # Save the model
162
+ torch.save(model.state_dict(), "model.pth")
163
+ print("Model trained and saved as 'model.pth'")
164
+
165
+ !pip install gradio
166
+
167
+ import gradio as gr
168
+
169
 
 
170
  def ask_question(question, max_new_tokens=100):
171
+ # Format the input context
172
  context_text = f"Q: {question}\nA:"
173
+ context_tokens = torch.tensor([encode(context_text)], dtype=torch.long, device=device)
174
+
175
+ # Generate the response
176
  generated_tokens = model.generate(context_tokens, max_new_tokens=max_new_tokens)
177
+
178
+ # Decode the generated tokens into text
179
  generated_text = decode(generated_tokens[0].tolist())
 
180
 
181
+ # Extract the answer (after "A:")
182
+ answer = generated_text.split("A:")[1].strip()
183
+ return answer
184
+
185
+ # Function to process the question
186
  def chatbot_response(question):
187
+ try:
188
+ answer = ask_question(question)
189
+ return f"Q: {question}\nA: {answer}"
190
+ except Exception as e:
191
+ return f"Error: {e}"
192
 
193
+ # Create a Gradio interface
194
  demo = gr.Interface(
195
  fn=chatbot_response,
196
  inputs="text",
197
  outputs="text",
198
  title="Religious Chatbot",
199
+ description="Ask questions about the book of Mormon, and the chatbot will generate answers based on its knowledge."
200
  )
201
 
202
  # Launch the app
203
+ demo.launch(share=True)
 
model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:faf7a48948c35f21225ee5871bf989f7a14d70a4cb1d138cd21702499cb7cb0d
3
  size 955314
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7a81949af5a132ffeeb6d9c6f0224663ebc79a4b64ac4254fc652b65280d478
3
  size 955314