Files changed (3)
  1. README.md +6 -6
  2. app.py +109 -536
  3. requirements.txt +6 -2
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
- title: TinkerSpace
- emoji:
  colorFrom: indigo
  colorTo: green
  sdk: gradio
  app_file: app.py
- pinned: true
- short_description: Demos for some my finetunes
- sdk_version: 6.2.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
  ---
+ title: Granite Vision 3.1 2B
+ emoji: 👀
  colorFrom: indigo
  colorTo: green
  sdk: gradio
+ sdk_version: 5.15.0
  app_file: app.py
+ pinned: false
+ license: apache-2.0
  ---

+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
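
Net effect on the Space config: assembling the unchanged and added lines above, the README front matter after this change reads

```yaml
---
title: Granite Vision 3.1 2B
emoji: 👀
colorFrom: indigo
colorTo: green
sdk: gradio
sdk_version: 5.15.0
app_file: app.py
pinned: false
license: apache-2.0
---
```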
app.py CHANGED
@@ -1,549 +1,122 @@
- import re
  import spaces
  import gradio as gr

- from threading import Thread
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-
- try:
-     import nltk
-     from nltk.tokenize import word_tokenize
-     from nltk.chunk import ne_chunk
-     from nltk.tag import pos_tag
-     NLTK_AVAILABLE = True
- except ImportError:
-     NLTK_AVAILABLE = False
-
-
- SYSTEM_PROMPT = """You are an expert creative director specializing in visual descriptions for image generation.
-
- Your task: Transform the user's concept into a rich, detailed image description while PRESERVING their core idea.
-
- IMPORTANT RULES:
- 1. Keep ALL key elements (intents, entities) from the original concept
- 2. Enhance with artistic details, NOT change the fundamental idea
- 3. Maintain the user's intended subject, action, and setting
-
- You should elaborate on:
- • Visual composition and perspective (bird's eye, close-up, wide angle, etc.)
- • Artistic style (photorealistic, impressionist, specific artist like Van Gogh, etc.)
- • Color palette and color temperature
- • Lighting (golden hour, dramatic shadows, soft diffused, etc.)
- • Atmosphere and mood
- • Textures and materials (rough, smooth, metallic, organic, etc.)
- • Technical details (medium, brushwork, rendering style)
- • Environmental context (time of day, weather, season, era)
- • Level of detail and focus points
-
- Output format: A single, flowing paragraph that reads naturally as an image prompt."""
-
- CUDA_AVAILABLE = False
-
- models = {}
- tokenizers = {}
-
- models[False] = AutoModelForCausalLM.from_pretrained("shb777/PromptTuner-v0.1")
- tokenizers[False] = AutoTokenizer.from_pretrained("shb777/PromptTuner-v0.1")
- models[False].eval()
-
- if CUDA_AVAILABLE:
-     models[True] = AutoModelForCausalLM.from_pretrained("shb777/PromptTuner-v0.1").to('cuda')
-     tokenizers[True] = tokenizers[False]
-     models[True].eval()
-
- # Download NLTK data
- if NLTK_AVAILABLE:
-     try:
-         nltk.data.find('tokenizers/punkt')
-         nltk.data.find('taggers/averaged_perceptron_tagger')
-         nltk.data.find('chunkers/maxent_ne_chunker')
-         nltk.data.find('corpora/words')
-     except LookupError:
-         nltk.download('punkt', quiet=True)
-         nltk.download('averaged_perceptron_tagger', quiet=True)
-         nltk.download('maxent_ne_chunker', quiet=True)
-         nltk.download('words', quiet=True)
-
-
- def extract_key_phrases(text: str) -> list:
-     if not NLTK_AVAILABLE:
-         words = re.findall(r'\b[a-zA-Z]{3,}\b', text.lower())
-         return list(set(words))
-
-     phrases = []
-     try:
-         tokens = word_tokenize(text)
-         tagged = pos_tag(tokens)
-         chunks = ne_chunk(tagged)
-
-         current_phrase = []
-         for chunk in chunks:
-             if hasattr(chunk, 'label'):
-                 phrase = ' '.join([token for token, _ in chunk.leaves()])
-                 phrases.append(phrase.lower())
-             elif chunk[1].startswith('NN'):
-                 current_phrase.append(chunk[0])
-             elif chunk[1].startswith('JJ') and current_phrase:
-                 current_phrase.append(chunk[0])
-             else:
-                 if current_phrase:
-                     phrases.append(' '.join(current_phrase).lower())
-                     current_phrase = []
-
-         if current_phrase:
-             phrases.append(' '.join(current_phrase).lower())
-
-         for word, tag in tagged:
-             if tag.startswith('JJ') or tag in ('RB', 'RBR', 'RBS'):
-                 phrases.append(word.lower())
-
-     except Exception:
-         words = re.findall(r'\b[a-zA-Z]{3,}\b', text.lower())
-         phrases = list(set(words))
-
-     # Also include original multi-word phrases
-     multi_word = re.findall(r'\b[a-zA-Z]{3,}(?:\s+[a-zA-Z]{3,}){1,3}\b', text)
-     phrases.extend([mw.lower() for mw in multi_word])
-
-     # Sort by length (longer first) and remove duplicates
-     phrases = list(set(phrases))
-     phrases.sort(key=len, reverse=True)
-
-     return phrases[:20]
-
-
- def highlight_matches(original_input: str, enhanced_output: str) -> str:
-     if not original_input.strip():
-         return f'<p class="output-text">{enhanced_output}</p>'
-
-     key_phrases = extract_key_phrases(original_input)
-     if not key_phrases:
-         return f'<p class="output-text">{enhanced_output}</p>'
-
-     # Sort by length (longer phrases first)
-     key_phrases.sort(key=len, reverse=True)
-
-     output = enhanced_output
-     highlighted_spans = []
-
-     for phrase in key_phrases:
-         pattern = re.compile(r'\b' + re.escape(phrase) + r'\b', re.IGNORECASE)
-
-         def replace_with_highlight(match):
-             matched_text = match.group(0)
-             start = match.start()
-             # Skip if already highlighted
-             for h_start, h_end in highlighted_spans:
-                 if start >= h_start and start <= h_end:
-                     return matched_text
-             highlighted_spans.append((start, match.end()))
-             return f'<mark class="highlight-keyword">{matched_text}</mark>'
-
-         output = pattern.sub(replace_with_highlight, output)
-
-     return f'<p class="output-text">{output}</p>'
-
-
- @spaces.GPU(duration=30)
- def generate_gpu(inputs, generation_kwargs):
-     return models[True].generate(**inputs, **generation_kwargs)
-
-
- def enhance_prompt(user_prompt: str, use_gpu=CUDA_AVAILABLE):
-     """Enhance the user's prompt using the AI model."""
-     # Validate input
-     if not user_prompt or not user_prompt.strip():
-         yield (
-             '<span class="placeholder-text">Please enter a prompt to enhance.</span>',
-             "",
-             gr.update(interactive=True),
-             gr.update(interactive=True)
-         )
-         return
-
-     # Prepare messages
-     messages = [
-         {"role": "system", "content": SYSTEM_PROMPT},
-         {"role": "user", "content": user_prompt}
-     ]
-
-     use_gpu = use_gpu and CUDA_AVAILABLE
-     tokenizer = tokenizers[False]
-
-     # Tokenize input
-     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     inputs = tokenizer(prompt, return_tensors="pt")
-
-     if use_gpu:
-         inputs = {k: v.to('cuda') for k, v in inputs.items()}
-
-     # Set up streaming
-     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
      generation_kwargs = {
-         'max_new_tokens': 512,
-         'streamer': streamer,
-         'do_sample': True,
-         'temperature': 1,
-         'top_p': 0.95,
-         'top_k': 64
      }

-     # Show loading state
-     placeholder = '<span class="placeholder-text">Your enhanced prompt will appear here</span>'
-     yield placeholder, "", gr.update(interactive=False), gr.update(interactive=False)
-
-     try:
-         # Start generation in a separate thread
-         if use_gpu:
-             thread = Thread(target=generate_gpu, kwargs={'inputs': inputs, 'generation_kwargs': generation_kwargs})
-         else:
-             thread = Thread(target=models[False].generate, kwargs={**inputs, **generation_kwargs})
-
-         thread.start()
-
-         # Stream output
-         output = ""
-         for text in streamer:
-             output += text
-             highlighted = highlight_matches(user_prompt, output)
-             yield highlighted, output, gr.update(), gr.update()
-
-     except gr.exceptions.Error as e:
-         if use_gpu:
-             gr.Warning(str(e))
-             gr.Info('Retrying with CPU')
-             inputs = {k: v.cpu() for k, v in inputs.items()}
-             thread = Thread(target=models[False].generate, kwargs={**inputs, **generation_kwargs})
-             thread.start()
-
-             output = ""
-             for text in streamer:
-                 output += text
-                 highlighted = highlight_matches(user_prompt, output)
-                 yield highlighted, output, gr.update(), gr.update()
-         else:
-             raise gr.Error(e)
-
-     # Final output with interactive buttons restored
-     final_highlighted = highlight_matches(user_prompt, output)
-     yield final_highlighted, output, gr.update(interactive=True), gr.update(interactive=True)
-
-
- # =============================================================================
- # CSS - shadcn/ui inspired Zinc Dark Theme
- # =============================================================================
- custom_css = """
- /* ========== CSS VARIABLES ========== */
- :root {
-     --background: 240 10% 3.9%;
-     --foreground: 0 0% 98%;
-     --card: 240 10% 4.5%;
-     --card-border: 240 3.7% 18%;
-     --primary: 0 0% 98%;
-     --primary-foreground: 240 5.9% 10%;
-     --secondary: 240 3.7% 15.9%;
-     --secondary-foreground: 0 0% 98%;
-     --muted: 240 3.7% 15.9%;
-     --muted-foreground: 240 5% 64.9%;
-     --accent: 240 3.7% 15.9%;
-     --accent-foreground: 0 0% 98%;
-     --border: 240 3.7% 18%;
-     --input: 240 3.7% 18%;
-     --ring: 240 5.9% 85%;
-     --radius: 0.625rem;
- }
-
- /* ========== GLOBAL STYLES ========== */
- .gradio-container {
-     font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
-     background: hsl(var(--background)) !important;
-     color: hsl(var(--foreground));
- }
-
- .gradio-container mark {
-     background: hsl(var(--accent) / 0.6);
-     color: hsl(var(--accent-foreground));
-     padding: 0.15em 0.35em;
-     border-radius: calc(var(--radius) - 2px);
-     font-weight: 500;
-     border: 1px solid hsl(var(--border) / 0.5);
- }
-
- footer { display: none !important; }
-
- /* ========== MARKDOWN ========== */
- .gradio-markdown {
-     color: hsl(var(--foreground)) !important;
-     font-size: 0.9375rem !important;
-     line-height: 1.6 !important;
- }
-
- .gradio-markdown:first-child {
-     margin-bottom: 2rem;
-     padding-bottom: 1.5rem;
-     border-bottom: 1px solid hsl(var(--border));
- }
-
- .gradio-markdown:last-child {
-     padding-top: 1.5rem;
-     border-top: 1px solid hsl(var(--border));
-     color: hsl(var(--muted-foreground)) !important;
- }
-
- .gradio-markdown a {
-     color: hsl(var(--foreground)) !important;
-     text-decoration: none;
-     border-bottom: 1px solid hsl(var(--border));
-     transition: border-color 0.2s ease;
- }
-
- .gradio-markdown a:hover {
-     border-color: hsl(var(--ring));
- }
-
- /* ========== LAYOUT ========== */
- .main-grid {
-     display: grid;
-     grid-template-columns: 1fr 1fr;
-     gap: 2rem;
- }
-
- @media (max-width: 768px) {
-     .main-grid { grid-template-columns: 1fr; }
- }
-
- /* ========== CARDS ========== */
- .card {
-     background: hsl(var(--card));
-     border: 1px solid hsl(var(--card-border));
-     border-radius: var(--radius);
-     padding: 1.5rem;
-     box-shadow: 0 1px 2px rgba(0, 0, 0, 0.3), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
- }
-
- /* ========== FORM ELEMENTS ========== */
- .form-label {
-     font-size: 0.875rem;
-     font-weight: 500;
-     margin-bottom: 0.5rem;
-     display: block;
-     color: hsl(var(--foreground));
- }
-
- .input-textarea {
-     width: 100%;
-     min-height: 140px;
-     padding: 0.875rem;
-     font-size: 0.9375rem;
-     line-height: 1.6;
-     background: hsl(var(--background));
-     border: 1px solid hsl(var(--input));
-     border-radius: var(--radius);
-     color: hsl(var(--foreground));
-     transition: all 0.2s ease;
-     resize: vertical;
-     box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
- }
-
- .input-textarea::placeholder {
-     color: hsl(var(--muted-foreground) / 0.7);
- }
-
- .input-textarea:focus {
-     outline: none;
-     border-color: hsl(var(--ring));
-     box-shadow: 0 0 0 3px hsl(var(--ring) / 0.1), 0 1px 2px rgba(0, 0, 0, 0.2);
-     background: hsl(var(--background) / 0.8);
- }
-
- /* ========== BUTTONS ========== */
- .btn {
-     display: inline-flex;
-     align-items: center;
-     justify-content: center;
-     gap: 0.5rem;
-     font-size: 0.9375rem;
-     font-weight: 500;
-     padding: 0.625rem 1.25rem;
-     border-radius: var(--radius);
-     cursor: pointer;
-     transition: all 0.2s ease;
-     border: none;
- }
-
- .btn:focus-visible {
-     outline: none;
-     box-shadow: 0 0 0 2px hsl(var(--background)), 0 0 0 4px hsl(var(--ring));
- }
-
- .btn-primary {
-     background: hsl(var(--primary));
-     color: hsl(var(--primary-foreground));
-     box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2), 0 0 0 1px rgba(255, 255, 255, 0.05) inset;
- }
-
- .btn-primary:hover {
-     opacity: 0.95;
-     box-shadow: 0 2px 4px rgba(0, 0, 0, 0.25), 0 0 0 1px rgba(255, 255, 255, 0.08) inset;
- }
-
- .btn-primary:active {
-     transform: translateY(1px);
- }
-
- .btn-primary:disabled {
-     opacity: 0.5;
-     cursor: not-allowed;
- }
-
- .btn-secondary {
-     background: hsl(var(--secondary));
-     color: hsl(var(--secondary-foreground));
-     border: 1px solid hsl(var(--border));
-     box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
- }
-
- .btn-secondary:hover {
-     background: hsl(var(--secondary) / 0.8);
-     border-color: hsl(var(--muted-foreground) / 0.5);
- }
-
- .btn-secondary:active {
-     transform: translateY(1px);
- }
-
- /* ========== OUTPUT CONTAINER ========== */
- .output-container {
-     min-height: 140px;
-     padding: 0.875rem;
-     border: 1px solid hsl(var(--input));
-     border-radius: var(--radius);
-     background: hsl(var(--background));
-     box-shadow: 0 1px 2px rgba(0, 0, 0, 0.15), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
- }
-
- .output-text {
-     color: hsl(var(--foreground));
-     font-size: 0.9375rem;
-     line-height: 1.75;
-     margin: 0;
- }
-
- .placeholder-text {
-     color: hsl(var(--muted-foreground));
- }
-
- .highlight-keyword {
-     background: hsl(var(--accent) / 0.6);
-     color: hsl(var(--accent-foreground));
-     padding: 0.15em 0.35em;
-     border-radius: calc(var(--radius) - 2px);
-     font-weight: 500;
-     border: 1px solid hsl(var(--border) / 0.5);
- }
-
- /* ========== EXAMPLES ========== */
- .examples-section {
-     padding: 1.5rem;
-     background: hsl(var(--card));
-     border: 1px solid hsl(var(--card-border));
-     border-radius: var(--radius);
-     box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2), 0 0 0 1px rgba(255, 255, 255, 0.02) inset;
- }
-
- /* ========== SPACING UTILITIES ========== */
- .mt-6 { margin-top: 1.5rem; }
- .flex { display: flex; }
- .gap-2 { gap: 0.5rem; }
- """
-
- # =============================================================================
- # Gradio Interface
- # =============================================================================
- with gr.Blocks(css=custom_css, title="Prompt Enhancer") as demo:
-     # Header
-     with gr.Row():
-         gr.Markdown("Transform your creative ideas into detailed, vivid prompts for AI image generation.")
-
-     # Main content - two column layout
-     with gr.Row(elem_classes=["main-grid"]):
-         # Input column
-         with gr.Column(elem_classes=["card"]):
-             gr.HTML('<label class="form-label">Input Prompt</label>')
-
-             input_text = gr.Textbox(
-                 placeholder="Describe your image concept... e.g., fox, red tail, blue moon, clouds",
-                 lines=5,
-                 show_label=False,
-                 autofocus=True,
-                 container=False,
-                 elem_classes=["input-textarea"]
-             )
-
-             with gr.Row(elem_classes=["flex gap-2 mt-6"]):
-                 enhance_btn = gr.Button(
-                     "Enhance Prompt",
-                     variant="primary",
-                     scale=2,
-                     elem_classes=["btn", "btn-primary"]
-                 )
-                 clear_btn = gr.Button(
-                     "Clear",
-                     scale=1,
-                     elem_classes=["btn", "btn-secondary"]
-                 )
-
-         # Output column
-         with gr.Column(elem_classes=["card"]):
-             gr.HTML('<label class="form-label">Enhanced Prompt</label>')
-
-             output_html = gr.HTML(
-                 value='<span class="placeholder-text">Your enhanced prompt will appear here</span>',
-                 elem_classes=["output-container"]
-             )
-
-             raw_output = gr.Textbox(visible=False)
-
-     # Examples section
-     with gr.Column(elem_classes=["examples-section"]):
-         gr.Examples(
-             examples=[
-                 ["fox, red tail, blue moon, clouds"],
-                 ["room with french window, cozy morning vibes, minimal"],
-                 ["anime style, sunset, japan"]
-             ],
-             inputs=input_text,
-             label="Examples"
-         )
-
-     # Footer
      with gr.Row():
-         gr.Markdown(
-             "Powered by [PromptTuner](https://huggingface.co/shb777/PromptTuner-v0.1), "
-             "a finetuned gemma3-270M model specifically designed to enhance text prompts "
-             "for text-to-image generation."
-         )
-
-     # =============================================================================
-     # Event Handlers
-     # =============================================================================
-     enhance_btn.click(
-         fn=enhance_prompt,
-         inputs=[input_text, gr.State(False)],
-         outputs=[output_html, raw_output, enhance_btn, clear_btn]
      )

-     clear_btn.click(
-         fn=lambda: (
-             "",
-             '<span class="placeholder-text">Your enhanced prompt will appear here</span>',
-             "",
-             gr.update(interactive=True),
-             gr.update(interactive=True)
-         ),
          inputs=None,
-         outputs=[input_text, output_html, raw_output, enhance_btn, clear_btn]
      )

  if __name__ == "__main__":
-     demo.queue(max_size=20).launch(mcp_server=True)

  import spaces
+ import random
+ import torch
  import gradio as gr
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
+
+ model_path = "ibm-granite/granite-vision-3.1-2b-preview"
+ processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
+ model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
+
+ def get_text_from_content(content):
+     texts = []
+     for item in content:
+         if item["type"] == "text":
+             texts.append(item["text"])
+         elif item["type"] == "image":
+             texts.append("<image>")
+     return " ".join(texts)
+
+ @spaces.GPU
+ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
+     if conversation is None:
+         conversation = []
+
+     user_content = []
+     if image is not None:
+         user_content.append({"type": "image", "image": image})
+     if text and text.strip():
+         user_content.append({"type": "text", "text": text.strip()})
+     if not user_content:
+         return conversation_display(conversation), conversation
+
+     conversation.append({
+         "role": "user",
+         "content": user_content
+     })
+
+     inputs = processor.apply_chat_template(
+         conversation,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt"
+     ).to("cuda")
+
+     torch.manual_seed(random.randint(0, 10000))

      generation_kwargs = {
+         "max_new_tokens": max_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "do_sample": True,
      }

+     output = model.generate(**inputs, **generation_kwargs)
+     assistant_response = processor.decode(output[0], skip_special_tokens=True)
+
+     conversation.append({
+         "role": "assistant",
+         "content": [{"type": "text", "text": assistant_response.strip()}]
+     })
+
+     return conversation_display(conversation), conversation
+
+ def conversation_display(conversation):
+     chat_history = []
+     for msg in conversation:
+         if msg["role"] == "user":
+             user_text = get_text_from_content(msg["content"])
+         elif msg["role"] == "assistant":
+             assistant_text = msg["content"][0]["text"].split("<|assistant|>")[-1].strip()
+             chat_history.append({"role": "user", "content": user_text})
+             chat_history.append({"role": "assistant", "content": assistant_text})
+     return chat_history
+
+ def clear_chat():
+     return [], [], "", None
+
+ with gr.Blocks(title="Granite Vision 3.1 2B", css="h1 { overflow: hidden; }") as demo:
+     gr.Markdown("# Granite Vision 3.1 2B")
+
      with gr.Row():
+         with gr.Column(scale=2):
+             image_input = gr.Image(type="pil", label="Upload Image (optional)")
+             with gr.Column():
+                 temperature_input = gr.Slider(minimum=0.0, maximum=2.0, value=0.2, step=0.01, label="Temperature")
+                 top_p_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top p")
+                 top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top k")
+                 max_tokens_input = gr.Slider(minimum=10, maximum=300, value=128, step=1, label="Max Tokens")
+
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type='messages')
+             text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
+             with gr.Row():
+                 send_button = gr.Button("Chat")
+                 clear_button = gr.Button("Clear Chat")
+
+
+     state = gr.State([])
+
+     send_button.click(
+         chat_inference,
+         inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input, state],
+         outputs=[chatbot, state]
      )

+     clear_button.click(
+         clear_chat,
          inputs=None,
+         outputs=[chatbot, state, text_input, image_input]
+     )
+
+     gr.Examples(
+         examples=[
+             ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]
+         ],
+         inputs=[image_input, text_input]
      )

  if __name__ == "__main__":
+     demo.launch()
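
For reviewers who want to exercise the new inference path outside the Gradio UI, a minimal standalone sketch follows. It is not part of this diff: it assumes a CUDA-capable machine plus `requests` and `Pillow`, uses greedy decoding instead of the app's sampling, and decodes only the newly generated tokens, so it does not need the "<|assistant|>" split that `conversation_display` performs on the full decoded sequence.

```python
import requests
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

model_path = "ibm-granite/granite-vision-3.1-2b-preview"
processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto"
)

# Same example the Space ships with (gradio's test bus image).
url = "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png"
image = Image.open(requests.get(url, stream=True).raw)

conversation = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": "What is this?"},
    ],
}]

# Mirrors chat_inference: tokenize the chat template directly into tensors.
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128, do_sample=False)

# Decode only the tokens generated after the prompt; the app instead decodes
# output[0] in full and later strips everything before "<|assistant|>".
new_tokens = output[0][inputs["input_ids"].shape[1]:]
print(processor.decode(new_tokens, skip_special_tokens=True))
```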
requirements.txt CHANGED
@@ -1,2 +1,6 @@
- transformers
- nltk

+ torch
+ torchvision
+ git+https://github.com/huggingface/transformers.git
+ gradio
+ accelerate
+ bitsandbytes