doggdad committed on
Commit
6cd256f
·
verified ·
1 Parent(s): a5e7768

Upload app_gradio_alternative.py

Browse files
Files changed (1) hide show
  1. app_gradio_alternative.py +134 -0
app_gradio_alternative.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
2
+ # Source for "Build a Large Language Model From Scratch"
3
+ # - https://www.manning.com/books/build-a-large-language-model-from-scratch
4
+ # Code: https://github.com/rasbt/LLMs-from-scratch
5
+
6
+ from pathlib import Path
7
+ import sys
8
+
9
+ import tiktoken
10
+ import torch
11
+ import gradio as gr
12
+
13
+ # For llms_from_scratch installation instructions, see:
14
+ # https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg
15
+ from previous_chapters import GPTModel
16
+
17
+ from previous_chapters import (
18
+ generate,
19
+ text_to_token_ids,
20
+ token_ids_to_text,
21
+ )
22
+
23
# Pick the compute device once at import time: CUDA when PyTorch reports it
# as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
24
+
25
+
26
def get_model_and_tokenizer():
    """
    Load the GPT-2 tokenizer and the finetuned GPT model used by the chat app.

    The finetuned weights are read from ``gpt2-small124M-sft.pth`` in the
    current working directory; the same file name is used for local
    development and for the Hugging Face deployment.

    Returns:
        A ``(tokenizer, model, config)`` tuple: the tiktoken ``"gpt2"``
        encoding, the ``GPTModel`` with the checkpoint loaded (moved to
        ``device`` and switched to evaluation mode), and the configuration
        dictionary the model was built from.

    Raises:
        SystemExit: With exit status 1 if the checkpoint file is missing.
    """

    # Sizes match the "small124M" checkpoint that is loaded below.
    GPT_CONFIG_124M = {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,  # Context length
        "emb_dim": 768,          # Embedding dimension
        "n_heads": 12,           # Number of attention heads
        "n_layers": 12,          # Number of layers
        "drop_rate": 0.0,        # Dropout rate
        "qkv_bias": True,        # Query-key-value bias
    }

    # Local development and the Hugging Face deployment use the same file name,
    # so a single path covers both.
    model_path = Path("gpt2-small124M-sft.pth")

    # Fail fast, before loading the tokenizer, and with a non-zero status so
    # the failure is visible to whatever launched the app.
    if not model_path.exists():
        print(
            "Could not find the model file. Please run the chapter 7 code "
            f"to generate the {model_path.name} file or upload it to this directory."
        )
        sys.exit(1)

    tokenizer = tiktoken.get_encoding("gpt2")

    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only host (and vice versa).
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()  # Set to evaluation mode

    return tokenizer, model, GPT_CONFIG_124M
67
+
68
+
69
def extract_response(response_text, input_text):
    """Return the part of ``response_text`` that follows the prompt.

    The first ``len(input_text)`` characters are dropped, every occurrence of
    the ``"### Response:"`` marker is removed from the remainder, and
    surrounding whitespace is stripped.
    """
    completion = response_text[len(input_text):]
    without_marker = completion.replace("### Response:", "")
    return without_marker.strip()
71
+
72
+
73
# Load the tokenizer, model, and model config once at import time; the
# functions below read these module-level names.
tokenizer, model, model_config = get_model_and_tokenizer()
75
+
76
+
77
def generate_response(message, max_new_tokens=100):
    """Generate a reply to ``message`` with the module-level fine-tuned model.

    Args:
        message: Instruction text inserted into the prompt template.
        max_new_tokens: Forwarded to ``generate`` as ``max_new_tokens``.

    Returns:
        The decoded model output after ``extract_response`` has removed the
        prompt prefix and the "### Response:" marker.
    """
    # Reseed on every call so the same message yields the same sampling state.
    torch.manual_seed(123)

    # Instruction-style prompt; line breaks are part of the template.
    prompt = (
        "Below is an instruction that describes a task. Write a response\n"
        "that appropriately completes the request.\n"
        "\n"
        "### Instruction:\n"
        f"{message}\n"
    )

    prompt_ids = text_to_token_ids(prompt, tokenizer).to(device)
    context_size = model_config["context_length"]

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        token_ids = generate(
            model=model,
            idx=prompt_ids,
            max_new_tokens=max_new_tokens,
            context_size=context_size,
            eos_id=50256,  # presumably the GPT-2 end-of-text token id; confirm
        )

    decoded = token_ids_to_text(token_ids, tokenizer)
    return extract_response(decoded, prompt)
101
+
102
+
103
+ # Create a custom chat interface without using ChatInterface class
104
def respond(message, chat_history):
    """Handle one submitted chat message.

    Generates a reply for ``message``, appends the ``(message, reply)`` pair
    to ``chat_history`` in place, and returns an empty string together with
    the updated history. The caller wires these outputs to the textbox and
    the chatbot, so the empty string clears the input box.
    """
    reply = generate_response(message)
    turn = (message, reply)
    chat_history.append(turn)
    cleared_input = ""
    return cleared_input, chat_history
108
+
109
+
110
# Assemble the chat UI from individual components inside a Blocks context.
with gr.Blocks(theme="soft") as demo:
    # Page heading and short description.
    gr.Markdown("# Fine-tuned GPT Model Chat")
    gr.Markdown("Chat with a fine-tuned GPT model from 'Build a Large Language Model From Scratch' by Sebastian Raschka")

    # Conversation display, input box, and a button that resets the chat.
    chatbot = gr.Chatbot(height=600)
    msg = gr.Textbox(placeholder="Ask me something...", container=False, scale=7)
    clear = gr.Button("Clear")

    # Submitting the textbox runs `respond`; its two return values go back
    # into the textbox and the chatbot respectively.
    msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
    # The clear button replaces the chatbot value with an empty list.
    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)

    # Example prompts that are fed into the textbox.
    gr.Examples(
        examples=[
            "What is the capital of France",
            "What is the opposite of 'wet'?",
            "Write a short poem about AI",
            "Explain the concept of attention in neural networks",
        ],
        inputs=msg,
    )


# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)