DavidBazaldua commited on
Commit
abe09bc
·
verified ·
1 Parent(s): f7a0f58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -103
app.py CHANGED
@@ -11,7 +11,6 @@ MODEL_ID = "DavidBazaldua/llama3_finetuned_transformes"
11
  DEVICE = "cpu"
12
  DTYPE = torch.float32
13
 
14
- # Limit CPU threads
15
  torch.set_num_threads(2)
16
 
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -28,32 +27,28 @@ DEFAULT_SYSTEM_PROMPT = (
28
  "You are a helpful, precise AI assistant. "
29
  "Always answer as briefly as possible. "
30
  "For fact-based questions, answer in one short sentence or a compact bullet list. "
31
- "Do not add explanations, background, or restate the question unless the user explicitly asks for it. "
32
- "Respond in English unless the user explicitly requests another language."
33
  )
34
 
35
  # ---------------------------------------------------------------------
36
  # Prompt building
37
  # ---------------------------------------------------------------------
38
 
39
-
40
  def build_prompt(system_prompt, context, history, user_message):
41
  messages = []
42
 
43
- if system_prompt and system_prompt.strip():
44
  messages.append({"role": "system", "content": system_prompt})
45
 
46
- if context and context.strip():
47
- messages.append(
48
- {
49
- "role": "system",
50
- "content": (
51
- "The following information is additional context. "
52
- "Use it only if it is relevant to the user's request:\n"
53
- f"{context}"
54
- ),
55
- }
56
- )
57
 
58
  for user, assistant in history:
59
  messages.append({"role": "user", "content": user})
@@ -61,19 +56,18 @@ def build_prompt(system_prompt, context, history, user_message):
61
 
62
  messages.append({"role": "user", "content": user_message})
63
 
64
- prompt = tokenizer.apply_chat_template(
65
  messages,
66
  tokenize=False,
67
  add_generation_prompt=True,
68
  )
69
- return prompt
70
 
71
 
72
  def generate_answer(system_prompt, context, message, history, max_tokens, temperature, top_p):
73
  if history is None:
74
  history = []
75
 
76
- if not system_prompt or system_prompt.strip() == "":
77
  system_prompt = DEFAULT_SYSTEM_PROMPT
78
 
79
  max_tokens = int(min(max_tokens, 128))
@@ -87,7 +81,7 @@ def generate_answer(system_prompt, context, message, history, max_tokens, temper
87
  ).to(DEVICE)
88
 
89
  with torch.no_grad():
90
- output_tokens = model.generate(
91
  **inputs,
92
  max_new_tokens=max_tokens,
93
  do_sample=True,
@@ -96,90 +90,75 @@ def generate_answer(system_prompt, context, message, history, max_tokens, temper
96
  pad_token_id=tokenizer.eos_token_id,
97
  )
98
 
99
- full_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
100
 
101
- if full_text.startswith(prompt):
102
- answer = full_text[len(prompt):].strip()
103
- else:
104
- answer = full_text.strip()
105
 
106
- history = history + [[message, answer]]
107
  return answer, history
108
 
109
 
110
  def chat(message, history, system_prompt, context, max_tokens, temperature, top_p):
111
- if history is None:
112
- history = []
113
-
114
- answer, updated_history = generate_answer(
115
- system_prompt=system_prompt,
116
- context=context,
117
- message=message,
118
- history=history,
119
- max_tokens=max_tokens,
120
- temperature=temperature,
121
- top_p=top_p,
122
  )
123
-
124
- return "", updated_history
125
-
126
 
127
  # ---------------------------------------------------------------------
128
- # Minimalist Gradio UI
129
  # ---------------------------------------------------------------------
130
 
131
- with gr.Blocks(css="""
132
- /* Make the whole app look cleaner and more minimal */
133
- body { font-family: system-ui, -apple-system, BlinkMacSystemFont, sans-serif; }
134
- #chat-title { font-size: 1.6rem; font-weight: 500; margin-bottom: 0.25rem; }
135
- #chat-subtitle { font-size: 0.9rem; color: #666; margin-bottom: 1.5rem; }
136
-
137
- /* Tighten spacing around the chatbot */
138
- .gradio-container { max-width: 900px; margin: 0 auto; }
139
-
140
- """) as demo:
141
- with gr.Column():
142
- gr.Markdown(
143
- """
144
- <div id="chat-title">Iris</div>
145
- <div id="chat-subtitle">Minimal chat interface for your fine-tuned Llama 3 model.</div>
146
- """,
147
- elem_id="header",
148
- )
149
 
150
- chatbot = gr.Chatbot(
151
- label="",
152
- height=420,
153
- )
 
 
 
154
 
155
- msg = gr.Textbox(
156
- label="",
157
- placeholder="Type your message and press Enter...",
158
- )
159
 
160
- with gr.Row():
161
- send_btn = gr.Button("Send", variant="primary")
162
- clear_btn = gr.Button("Clear history")
163
-
164
- with gr.Accordion("Advanced settings", open=False):
165
  system_prompt_box = gr.Textbox(
166
  label="System prompt",
167
  value=DEFAULT_SYSTEM_PROMPT,
168
  lines=5,
169
  )
 
170
  context_box = gr.Textbox(
171
- label="Additional context",
172
- placeholder="Optional: paste any reference text or notes you want the model to use as context.",
173
  lines=6,
174
  )
175
 
176
  max_tokens_slider = gr.Slider(
177
- label="Max new tokens",
178
  minimum=32,
179
  maximum=256,
180
  value=128,
181
  step=16,
182
  )
 
183
  temperature_slider = gr.Slider(
184
  label="Temperature",
185
  minimum=0.1,
@@ -187,6 +166,7 @@ body { font-family: system-ui, -apple-system, BlinkMacSystemFont, sans-serif; }
187
  value=0.7,
188
  step=0.1,
189
  )
 
190
  top_p_slider = gr.Slider(
191
  label="Top-p",
192
  minimum=0.1,
@@ -195,37 +175,16 @@ body { font-family: system-ui, -apple-system, BlinkMacSystemFont, sans-serif; }
195
  step=0.05,
196
  )
197
 
198
- inputs = [
199
- msg,
200
- chatbot,
201
- system_prompt_box,
202
- context_box,
203
- max_tokens_slider,
204
- temperature_slider,
205
- top_p_slider,
206
- ]
207
- outputs = [msg, chatbot]
208
 
209
- msg.submit(
210
- fn=chat,
211
- inputs=inputs,
212
- outputs=outputs,
213
- )
214
 
215
- send_btn.click(
216
- fn=chat,
217
- inputs=inputs,
218
- outputs=outputs,
219
- )
220
-
221
- clear_btn.click(
222
- lambda: [],
223
- None,
224
- chatbot,
225
- queue=False,
226
- )
227
 
228
  if __name__ == "__main__":
229
  demo.launch()
230
 
231
-
 
11
  DEVICE = "cpu"
12
  DTYPE = torch.float32
13
 
 
14
  torch.set_num_threads(2)
15
 
16
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
27
  "You are a helpful, precise AI assistant. "
28
  "Always answer as briefly as possible. "
29
  "For fact-based questions, answer in one short sentence or a compact bullet list. "
30
+ "Do not add explanations unless the user explicitly asks for them. "
31
+ "Respond in English unless the user asks otherwise."
32
  )
33
 
34
  # ---------------------------------------------------------------------
35
  # Prompt building
36
  # ---------------------------------------------------------------------
37
 
 
38
  def build_prompt(system_prompt, context, history, user_message):
39
  messages = []
40
 
41
+ if system_prompt.strip():
42
  messages.append({"role": "system", "content": system_prompt})
43
 
44
+ if context.strip():
45
+ messages.append({
46
+ "role": "system",
47
+ "content": (
48
+ "The following information is additional context. "
49
+ "Use it only if relevant:\n" + context
50
+ )
51
+ })
 
 
 
52
 
53
  for user, assistant in history:
54
  messages.append({"role": "user", "content": user})
 
56
 
57
  messages.append({"role": "user", "content": user_message})
58
 
59
+ return tokenizer.apply_chat_template(
60
  messages,
61
  tokenize=False,
62
  add_generation_prompt=True,
63
  )
 
64
 
65
 
66
  def generate_answer(system_prompt, context, message, history, max_tokens, temperature, top_p):
67
  if history is None:
68
  history = []
69
 
70
+ if not system_prompt.strip():
71
  system_prompt = DEFAULT_SYSTEM_PROMPT
72
 
73
  max_tokens = int(min(max_tokens, 128))
 
81
  ).to(DEVICE)
82
 
83
  with torch.no_grad():
84
+ outputs = model.generate(
85
  **inputs,
86
  max_new_tokens=max_tokens,
87
  do_sample=True,
 
90
  pad_token_id=tokenizer.eos_token_id,
91
  )
92
 
93
+ decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
94
 
95
+ # Try to extract only the new part
96
+ answer = decoded[len(prompt):].strip() if decoded.startswith(prompt) else decoded.strip()
 
 
97
 
98
+ history.append([message, answer])
99
  return answer, history
100
 
101
 
102
  def chat(message, history, system_prompt, context, max_tokens, temperature, top_p):
103
+ answer, history = generate_answer(
104
+ system_prompt, context, message, history, max_tokens, temperature, top_p
 
 
 
 
 
 
 
 
 
105
  )
106
+ return "", history
 
 
107
 
108
  # ---------------------------------------------------------------------
109
+ # Minimalist ChatGPT-style UI
110
  # ---------------------------------------------------------------------
111
 
112
+ CSS = """
113
+ #container {max-width: 1200px; margin-left: auto; margin-right: auto;}
114
+ #chat-column {width: 75%;}
115
+ #sidebar {width: 25%; padding-left: 20px;}
116
+ #input-row {margin-top: 12px;}
117
+ """
118
+
119
+ with gr.Blocks(css=CSS) as demo:
120
+ gr.Markdown("<h2 style='font-weight:600;'>Iris – Your Fine-Tuned Llama 3 Assistant</h2>")
121
+
122
+ with gr.Row(elem_id="container"):
123
+ # LEFT SIDE: CHAT
124
+ with gr.Column(elem_id="chat-column"):
125
+ chatbot = gr.Chatbot(
126
+ height=500,
127
+ show_label=False,
128
+ )
 
129
 
130
+ with gr.Row(elem_id="input-row"):
131
+ msg = gr.Textbox(
132
+ placeholder="Send a message...",
133
+ scale=8,
134
+ show_label=False,
135
+ )
136
+ send_btn = gr.Button("Send", scale=2)
137
 
138
+ # RIGHT SIDE: SIDEBAR
139
+ with gr.Column(elem_id="sidebar"):
140
+ gr.Markdown("### Settings")
 
141
 
 
 
 
 
 
142
  system_prompt_box = gr.Textbox(
143
  label="System prompt",
144
  value=DEFAULT_SYSTEM_PROMPT,
145
  lines=5,
146
  )
147
+
148
  context_box = gr.Textbox(
149
+ label="Context",
150
+ placeholder="Optional reference text...",
151
  lines=6,
152
  )
153
 
154
  max_tokens_slider = gr.Slider(
155
+ label="Max tokens",
156
  minimum=32,
157
  maximum=256,
158
  value=128,
159
  step=16,
160
  )
161
+
162
  temperature_slider = gr.Slider(
163
  label="Temperature",
164
  minimum=0.1,
 
166
  value=0.7,
167
  step=0.1,
168
  )
169
+
170
  top_p_slider = gr.Slider(
171
  label="Top-p",
172
  minimum=0.1,
 
175
  step=0.05,
176
  )
177
 
178
+ clear_btn = gr.Button("Clear chat")
 
 
 
 
 
 
 
 
 
179
 
180
+ # Chat events
181
+ inputs = [msg, chatbot, system_prompt_box, context_box, max_tokens_slider, temperature_slider, top_p_slider]
182
+ outputs = [msg, chatbot]
 
 
183
 
184
+ msg.submit(chat, inputs, outputs)
185
+ send_btn.click(chat, inputs, outputs)
186
+ clear_btn.click(lambda: [], None, chatbot)
 
 
 
 
 
 
 
 
 
187
 
188
  if __name__ == "__main__":
189
  demo.launch()
190