AlissenMoreno61 commited on
Commit
ac2b76e
Β·
verified Β·
1 Parent(s): ba4ce6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -271
app.py CHANGED
@@ -1,278 +1,169 @@
1
- import torch
2
- import torch.nn.functional as F
3
  import gradio as gr
4
- from typing import List, Tuple
5
- from transformers import (
6
- AutoTokenizer,
7
- AutoModelForSeq2SeqLM,
8
- )
9
-
10
-
11
- # -----------------------------
12
- # Models
13
- # -----------------------------
14
- # Seq2Seq for generation (use selected FLAN-T5) and summarization (DistilBART)
15
- GEN_MODEL_NAME = "google/flan-t5-small"
16
- gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
17
- gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
18
-
19
- SUM_MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
20
- sum_tokenizer = AutoTokenizer.from_pretrained(SUM_MODEL_NAME)
21
- sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUM_MODEL_NAME)
22
-
23
-
24
- # -----------------------------
25
- # Helpers
26
- # -----------------------------
27
- def _format_token(token_text: str) -> str:
28
- """Format a decoded token for table display."""
29
- if token_text.strip() == "":
30
- # visualize whitespace-only tokens
31
- shown = token_text.replace(" ", "␠") or "(space)"
32
- else:
33
- shown = token_text
34
- return f"<code class='tok'>{shown}</code>"
35
-
36
-
37
- def _alternatives_table(
38
- chosen_token_ids: List[int],
39
- scores: List[torch.Tensor], # list of [batch, vocab] logits for each step
40
- tokenizer,
41
- top_k: int = 5,
42
- special_ok: bool = False,
43
- ) -> str:
44
- """Build an HTML table listing alternatives per generated step.
45
-
46
- chosen_token_ids should align 1:1 with scores.
47
- """
48
- rows: List[str] = ["<table class='prediction-table'>",
49
- "<tr><th>Generated Token</th><th>Top Alternatives</th></tr>"]
50
-
51
- special_ids = set(tokenizer.all_special_ids or [])
52
- for step, (chosen_id, step_scores) in enumerate(zip(chosen_token_ids, scores)):
53
- probs = F.softmax(step_scores[0], dim=-1) if step_scores.dim() == 2 else F.softmax(step_scores, dim=-1)
54
-
55
- # Get a surplus then filter out the chosen token and (optionally) specials
56
- k_surplus = max(top_k + 10, 20)
57
- top_vals, top_idx = torch.topk(probs, k=min(k_surplus, probs.numel()))
58
-
59
- alts: List[Tuple[str, float]] = []
60
- for idx, p in zip(top_idx.tolist(), top_vals.tolist()):
61
- if idx == chosen_id:
62
- continue
63
- if not special_ok and idx in special_ids:
64
- continue
65
- token_text = tokenizer.decode([idx])
66
- if token_text == "":
67
- continue
68
- alts.append((token_text, float(p)))
69
- if len(alts) >= top_k:
70
- break
71
-
72
- chosen_text = tokenizer.decode([chosen_id])
73
- chosen_fmt = _format_token(chosen_text)
74
- if alts:
75
- alt_fmt = ", ".join(f"{_format_token(t)} <span class='p'>({p*100:.1f}%)</span>" for t, p in alts)
76
- else:
77
- alt_fmt = "β€”"
78
- rows.append(f"<tr><td class='token-cell'>{chosen_fmt}</td><td class='pred-cell'>{alt_fmt}</td></tr>")
79
-
80
- rows.append("</table>")
81
- return "\n".join(rows)
82
-
83
-
84
- # -----------------------------
85
- # Inference functions
86
- # -----------------------------
87
- def generate_with_alternatives(prompt: str, max_tokens: int, temperature: float, top_k: int):
88
- if not prompt.strip():
89
- return "", ""
90
-
91
- inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True)
92
-
93
- with torch.no_grad():
94
- outputs = gen_model.generate(
95
- **inputs,
96
- max_new_tokens=int(max_tokens),
97
- do_sample=True,
98
- temperature=float(temperature),
99
- top_p=0.9,
100
- repetition_penalty=1.1,
101
- return_dict_in_generate=True,
102
- output_scores=True,
103
- )
104
-
105
- # For encoder-decoder, sequences contain decoder tokens; first is usually decoder_start
106
- seq = outputs.sequences[0]
107
- text_out = gen_tokenizer.decode(seq, skip_special_tokens=True)
108
-
109
- # Align chosen tokens with scores (skip the decoder start token)
110
- chosen_ids = seq[1: 1 + len(outputs.scores)].tolist()
111
-
112
- alt_html = _alternatives_table(
113
- chosen_token_ids=chosen_ids,
114
- scores=outputs.scores,
115
- tokenizer=gen_tokenizer,
116
- top_k=int(top_k),
117
  )
118
-
119
- alt_wrapper = f"""
120
- <div class='response-box fade-in'>
121
- <span class='response-label'>Alternative Next Tokens (Generation)</span>
122
- {alt_html}
123
- </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  """
125
-
126
- return text_out, alt_wrapper
127
-
128
-
129
- def summarize_with_alternatives(text: str, max_tokens: int, num_beams: int, top_k: int):
130
- if not text.strip():
131
- return "", ""
132
-
133
- inputs = sum_tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
134
-
135
- with torch.no_grad():
136
- outputs = sum_model.generate(
137
- **inputs,
138
- max_new_tokens=int(max_tokens),
139
- num_beams=int(num_beams),
140
- early_stopping=True,
141
- return_dict_in_generate=True,
142
- output_scores=True,
143
- )
144
-
145
- seq = outputs.sequences[0]
146
- summary_text = sum_tokenizer.decode(seq, skip_special_tokens=True)
147
-
148
- # For Seq2Seq, scores length should match tokens after decoder start token
149
- # Align chosen token ids with scores (skip the first token which is usually decoder_start)
150
- chosen_ids = seq[1: 1 + len(outputs.scores)].tolist()
151
-
152
- alt_html = _alternatives_table(
153
- chosen_token_ids=chosen_ids,
154
- scores=outputs.scores,
155
- tokenizer=sum_tokenizer,
156
- top_k=int(top_k),
157
- )
158
-
159
- alt_wrapper = f"""
160
- <div class='response-box fade-in'>
161
- <span class='response-label'>Alternative Next Tokens (Summarization)</span>
162
- {alt_html}
163
- </div>
164
- """
165
-
166
- return summary_text, alt_wrapper
167
-
168
-
169
- # -----------------------------
170
- # UI
171
- # -----------------------------
172
- custom_css = """
173
- body {
174
- background: linear-gradient(135deg, #f3cadb, #f6b8d2, #f7a8c9);
175
- }
176
- .gradio-container { background: transparent !important; }
177
-
178
- textarea, input {
179
- background-color: #f2c8da !important;
180
- color: #2b0f1a !important;
181
- border-radius: 12px !important;
182
- border: 1px solid #8e2c4a !important;
183
- font-size: 16px !important;
184
- }
185
- button {
186
- background: #d76c91 !important;
187
- color: #2b0f1a !important;
188
- font-weight: bold !important;
189
- border-radius: 12px !important;
190
- border: 1px solid #8e2c4a !important;
191
- }
192
- button:hover {
193
- transform: scale(1.03);
194
- background: #c55b82 !important;
195
- }
196
-
197
- .response-box {
198
- background-color: rgba(255, 230, 235, 0.9);
199
- border: 2px solid #ffb6c1;
200
- border-radius: 15px;
201
- padding: 16px;
202
- color: #4b2b30;
203
- font-size: 15px;
204
- line-height: 1.6;
205
- margin-top: 10px;
206
- box-shadow: 3px 4px 8px rgba(255, 182, 193, 0.3);
207
- white-space: normal;
208
- width: 100%;
209
- opacity: 0;
210
- transition: opacity 0.8s ease-in-out;
211
- }
212
- .fade-in { opacity: 1 !important; }
213
- .response-label {
214
- font-weight: bold;
215
- color: #d36b83;
216
- font-size: 18px;
217
- font-family: "Poppins", sans-serif;
218
- margin-bottom: 8px;
219
- display: block;
220
- border-bottom: 1.5px dashed #ffc9d6;
221
- padding-bottom: 5px;
222
- text-align: center;
223
- }
224
-
225
- .prediction-table { width: 100%; border-collapse: collapse; margin-top: 6px; }
226
- .prediction-table th {
227
- background-color: #ffe6eb;
228
- color: #d36b83;
229
- text-align: center;
230
- font-family: "Poppins", sans-serif;
231
- padding: 6px;
232
- border-bottom: 2px dashed #ffc9d6;
233
- }
234
- .prediction-table td { padding: 6px 10px; color: #4b2b30; border-bottom: 1px solid #ffd3de; vertical-align: top; }
235
- .token-cell { font-weight: bold; text-align: right; width: 25%; color: #c45c77; padding-right: 10px; }
236
- .pred-cell { text-align: left; width: 75%; }
237
- .tok { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; }
238
- .p { color: #997; }
239
- """
240
-
241
-
242
- with gr.Blocks(css=custom_css) as demo:
243
- gr.Markdown("<h1>🌸 App 3: Generate & Summarize with Alternatives</h1>")
244
- gr.Markdown("<p>Create text or summaries, and inspect alternative tokens considered at each step.</p>")
245
-
246
  with gr.Tab("πŸ’¬ Text Generation"):
247
- gen_prompt = gr.Textbox(label="Enter your prompt:")
248
- with gr.Row():
249
- gen_max_tokens = gr.Slider(10, 300, value=120, step=10, label="Max New Tokens")
250
- gen_temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
251
- gen_top_k = gr.Slider(2, 10, value=5, step=1, label="Top-k Alternatives")
252
- with gr.Row():
253
- gen_output = gr.Textbox(label="Generated Output", lines=8)
254
- gen_alts = gr.HTML(label="Alternatives", value="")
255
- gr.Button("✨ Generate").click(
256
- generate_with_alternatives,
257
- inputs=[gen_prompt, gen_max_tokens, gen_temperature, gen_top_k],
258
- outputs=[gen_output, gen_alts],
259
- )
260
 
 
261
  with gr.Tab("🧁 Text Summarization"):
262
- sum_text = gr.Textbox(label="Paste your text:", lines=8)
263
- with gr.Row():
264
- sum_max_tokens = gr.Slider(20, 200, value=120, step=10, label="Max New Tokens")
265
- sum_beams = gr.Slider(1, 6, value=4, step=1, label="Beams")
266
- sum_top_k = gr.Slider(2, 10, value=5, step=1, label="Top-k Alternatives")
267
- with gr.Row():
268
- sum_output = gr.Textbox(label="Summary", lines=8)
269
- sum_alts = gr.HTML(label="Alternatives", value="")
270
- gr.Button("πŸŽ€ Summarize").click(
271
- summarize_with_alternatives,
272
- inputs=[sum_text, sum_max_tokens, sum_beams, sum_top_k],
273
- outputs=[sum_output, sum_alts],
274
- )
275
-
276
 
277
- if __name__ == "__main__":
278
- demo.launch()
 
 
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ # 🌷 Load improved, instruction-tuned models for quality responses
5
+ generator = pipeline("text2text-generation", model="google/flan-t5-small")
6
+ summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
7
+
8
+ # πŸ’¬ Generate creative or factual text
9
+ def generate_text(prompt, max_tokens, temperature):
10
+ output = generator(
11
+ prompt,
12
+ max_new_tokens=int(max_tokens),
13
+ temperature=float(temperature),
14
+ top_p=0.9,
15
+ repetition_penalty=2.0,
16
+ do_sample=True, # πŸ’– enables randomness and exploration
17
+ num_return_sequences=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  )
19
+ return output[0]["generated_text"]
20
+
21
+
22
+
23
+
24
+ # βœ‚οΈ Summarize text
25
+ def summarize_text(text):
26
+ summary = summarizer(text, max_length=150, min_length=40, do_sample=False)
27
+ return summary[0]["summary_text"]
28
+
29
+ # 🌸 Build Gradio Interface
30
+ with gr.Blocks(
31
+ theme=gr.themes.Soft(primary_hue="pink"),
32
+ css="""
33
+ body {
34
+ background: linear-gradient(135deg, #f3cadb, #f6b8d2, #f7a8c9);
35
+ font-family: 'Poppins', sans-serif;
36
+ color: #3b1f2b;
37
+ }
38
+
39
+ h1 {
40
+ text-align: center;
41
+ font-size: 2.5em;
42
+ color: #8e2c4a;
43
+ text-shadow: 0 0 6px rgba(142,44,74,0.25);
44
+ }
45
+
46
+ p {
47
+ text-align: center;
48
+ color: #3b1f2b !important;
49
+ }
50
+
51
+ /* Info box */
52
+ .helper-box {
53
+ background: #f8d7e2;
54
+ border-radius: 18px;
55
+ padding: 20px;
56
+ margin: 25px auto;
57
+ width: 85%;
58
+ box-shadow: 0 0 12px rgba(122, 41, 64, 0.25);
59
+ line-height: 1.6;
60
+ color: #3b1f2b;
61
+ }
62
+
63
+ .helper-box h3 {
64
+ color: #8e2c4a;
65
+ text-align: center;
66
+ font-size: 1.3em;
67
+ }
68
+
69
+ .helper-box b {
70
+ color: #6e1b3f;
71
+ }
72
+
73
+ .helper-box i, .helper-box span, .helper-box li {
74
+ color: #3b1f2b !important;
75
+ }
76
+
77
+ /* Labels */
78
+ label {
79
+ color: #4a1e33 !important;
80
+ font-weight: 600;
81
+ }
82
+
83
+ /* Input areas */
84
+ .gr-textbox textarea {
85
+ background-color: #f2c8da !important;
86
+ border-radius: 12px !important;
87
+ color: #2b0f1a !important;
88
+ }
89
+
90
+ #response-box textarea {
91
+ min-height: 220px !important;
92
+ background-color: #f3c6d9 !important;
93
+ border-radius: 12px !important;
94
+ padding: 10px !important;
95
+ font-size: 1em !important;
96
+ color: #2b0f1a !important;
97
+ box-shadow: inset 0 0 6px rgba(122, 41, 64, 0.25);
98
+ }
99
+
100
+ /* Buttons */
101
+ .gr-button {
102
+ background: #d76c91 !important;
103
+ color: #2b0f1a !important;
104
+ font-weight: bold;
105
+ border-radius: 12px !important;
106
+ border: 1px solid #8e2c4a !important;
107
+ }
108
+
109
+ .gr-button:hover {
110
+ transform: scale(1.05);
111
+ background: #c55b82 !important;
112
+ box-shadow: 0 0 8px rgba(142,44,74,0.4);
113
+ }
114
+
115
+ /* Tabs */
116
+ .tab-nav > button {
117
+ color: #3b1f2b !important;
118
+ font-weight: 600;
119
+ background-color: transparent !important;
120
+ }
121
+
122
+ .tab-nav > button.selected {
123
+ color: #8e2c4a !important;
124
+ border-bottom: 3px solid #8e2c4a !important;
125
+ }
126
+
127
+ ul li {
128
+ margin: 6px 0;
129
+ }
130
  """
131
+ ) as app:
132
+
133
+ # 🌸 Title
134
+ gr.Markdown("<h1>🌸 AI Assistant</h1>")
135
+ gr.Markdown("<p>✨ Generate or Summarize Text using Pretrained Hugging Face Models ✨</p>")
136
+
137
+ # πŸ’‘ Instructions Box
138
+ gr.HTML("""
139
+ <div class="helper-box">
140
+ <h3>πŸ’‘ How It Works</h3>
141
+ <p style="color:#3b1f2b;">
142
+ Choose between <b>text generation</b> or <b>text summarization</b> below!<br>
143
+ This app uses improved open-source models for more natural and reliable text output.
144
+ </p>
145
+ <ul>
146
+ <li><b>Text Generation:</b> <span>Type a creative prompt, like <i>β€œWrite a story about a butterfly in Texas.”</i></span></li>
147
+ <li><b>Text Summarization:</b> <span>Paste a paragraph or article to get a clear, shorter summary.</span></li>
148
+ </ul>
149
+ <p style="color:#3b1f2b;">πŸŽ€ <b>Max Tokens:</b> Longer responses use more tokens.</p>
150
+ <p style="color:#3b1f2b;">🌑️ <b>Temperature:</b> Controls creativity β€” lower = logical, higher = imaginative.</p>
151
+ </div>
152
+ """)
153
+
154
+ # πŸ’¬ Text Generation Tab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  with gr.Tab("πŸ’¬ Text Generation"):
156
+ prompt = gr.Textbox(label="πŸ’– Enter your prompt:")
157
+ max_tokens = gr.Slider(50, 500, value=150, step=10, label="πŸŽ€ Max Tokens")
158
+ temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="🌑️ Temperature")
159
+ generation_output = gr.Textbox(label="πŸ’¬ Generated Response", lines=10, elem_id="response-box")
160
+ gr.Button("✨ Generate ✨").click(generate_text, [prompt, max_tokens, temperature], generation_output)
 
 
 
 
 
 
 
 
161
 
162
+ # 🧁 Text Summarization Tab
163
  with gr.Tab("🧁 Text Summarization"):
164
+ input_text = gr.Textbox(label="πŸ“ Paste your text:", lines=10)
165
+ summary_output = gr.Textbox(label="🍬 Summary", lines=10, elem_id="response-box")
166
+ gr.Button("πŸŽ€ Summarize πŸŽ€").click(summarize_text, [input_text], summary_output)
 
 
 
 
 
 
 
 
 
 
 
167
 
168
+ # πŸš€ Launch
169
+ app.launch()