[upload demo notebook]

#2
by prithivMLmods - opened
deepattricap-vla-3b-colab-notebook-demo/DeepAttriCap_VLA_3B.ipynb ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "## **DeepAttriCap-VLA-3B**\n"
7
+ ],
8
+ "metadata": {
9
+ "id": "XN7ploRWa4GE"
10
+ }
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "metadata": {
15
+ "id": "uFovmijgUV1Z"
16
+ },
17
+ "source": [
18
+ "The DeepAttriCap-VLA-3B model is a fine-tuned version of Qwen2.5-VL-3B-Instruct, tailored for Vision-Language Attribution and Image Captioning. This variant is designed to generate precise, attribute-rich descriptions that define the visual properties of objects and scenes in detail, ensuring both object-level identification and contextual captioning. Its Vision-Language Attribution capability produces structured captions with explicit object attributes, properties, and contextual details.\n",
19
+ "\n",
20
+ "\n",
21
+ "| IMG 1 | IMG 2 |\n",
22
+ "|-------|-------|\n",
23
+ "| ![Screenshot 2025-08-28 at 21-30-25 Gradio.png](https://cdn-uploads.huggingface.co/production/uploads/65bb837dbfb878f46c77de4c/sOsUkjrn4ElKetpQ_ap6t.png) | ![Screenshot 2025-08-28 at 21-30-52 Gradio.png](https://cdn-uploads.huggingface.co/production/uploads/65bb837dbfb878f46c77de4c/IXxvpDXVyFXcnLSQoP1NN.png) |\n",
24
+ "\n",
25
+ "\n",
26
+ "notebook by: [prithivMLmods](https://huggingface.co/prithivMLmods)🤗"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "markdown",
31
+ "metadata": {
32
+ "id": "RugX4SGZV-8O"
33
+ },
34
+ "source": [
35
+ "### **Install Packages**"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {
42
+ "id": "l-NtFtjSpuJQ"
43
+ },
44
+ "outputs": [],
45
+ "source": [
46
+ "%%capture\n",
47
+ "!pip install git+https://github.com/huggingface/transformers.git \\\n",
48
+ " git+https://github.com/huggingface/accelerate.git \\\n",
49
+ " git+https://github.com/huggingface/peft.git \\\n",
50
+ " transformers-stream-generator huggingface_hub albumentations \\\n",
51
+ " pyvips-binary qwen-vl-utils sentencepiece opencv-python docling-core \\\n",
52
+ " python-docx torchvision safetensors matplotlib num2words \\\n",
53
+ "\n",
54
+ "!pip install xformers requests pymupdf hf_xet spaces pyvips pillow gradio \\\n",
55
+ " einops torch fpdf timm av decord bitsandbytes reportlab\n",
56
+ "#Hold tight, this will take around 2-3 minutes."
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "metadata": {
62
+ "id": "mvoSnRZcVBu4"
63
+ },
64
+ "source": [
65
+ "### **Run DeepAttriCap-VLA-3B Demo**\n",
66
+ "\n"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "from huggingface_hub import notebook_login, HfApi\n",
73
+ "notebook_login()"
74
+ ],
75
+ "metadata": {
76
+ "id": "db6vq0T8deoB"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "metadata": {
85
+ "id": "tElKr2Fkp1bO"
86
+ },
87
+ "outputs": [],
88
+ "source": [
89
+ "import spaces\n",
90
+ "import json\n",
91
+ "import math\n",
92
+ "import os\n",
93
+ "import traceback\n",
94
+ "from io import BytesIO\n",
95
+ "from typing import Any, Dict, List, Optional, Tuple\n",
96
+ "import re\n",
97
+ "import time\n",
98
+ "from threading import Thread\n",
99
+ "from io import BytesIO\n",
100
+ "import uuid\n",
101
+ "import tempfile\n",
102
+ "\n",
103
+ "import gradio as gr\n",
104
+ "import requests\n",
105
+ "import torch\n",
106
+ "from PIL import Image\n",
107
+ "import fitz\n",
108
+ "\n",
109
+ "from transformers import (\n",
110
+ " Qwen2_5_VLForConditionalGeneration,\n",
111
+ " AutoProcessor,\n",
112
+ " TextIteratorStreamer,\n",
113
+ ")\n",
114
+ "\n",
115
+ "from transformers.image_utils import load_image\n",
116
+ "\n",
117
+ "from reportlab.lib.pagesizes import A4\n",
118
+ "from reportlab.lib.styles import getSampleStyleSheet\n",
119
+ "from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer\n",
120
+ "from reportlab.lib.units import inch\n",
121
+ "\n",
122
+ "# --- Constants and Model Setup ---\n",
123
+ "MAX_INPUT_TOKEN_LENGTH = 4096\n",
124
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
125
+ "\n",
126
+ "# --- System Prompt Definition ---\n",
127
+ "CAPTION_SYSTEM_PROMPT = \"\"\"\n",
128
+ "You are an AI assistant that rigorously follows this response protocol:\n",
129
+ "\n",
130
+ "1. For every input image, your primary task is to write a **precise caption**. The caption must capture the **essence of the image** in clear, concise, and contextually accurate language.\n",
131
+ "\n",
132
+ "2. Along with the caption, provide a structured set of **attributes** that describe the visual elements. Attributes should include details such as objects, people, actions, colors, environment, mood, and other notable characteristics.\n",
133
+ "\n",
134
+ "3. Always include a **class_name** field. This must represent the **core theme or main subject** of the image in a compact format.\n",
135
+ " - Use the syntax: `{class_name==write_the_core_theme}`\n",
136
+ " - Example: `{class_name==dog_playing}` or `{class_name==city_sunset}`\n",
137
+ "\n",
138
+ "4. Maintain the following strict format in your output:\n",
139
+ " - **Caption:** <one-sentence description>\n",
140
+ " - **Attributes:** <comma-separated list of visual attributes>\n",
141
+ " - **{class_name==core_theme}**\n",
142
+ "\n",
143
+ "5. Ensure captions are **precise, neutral, and descriptive**, avoiding unnecessary elaboration or subjective interpretation unless explicitly required.\n",
144
+ "\n",
145
+ "6. Do not reference the rules or instructions in the output. Only return the formatted caption, attributes, and class_name.\n",
146
+ "\n",
147
+ "\"\"\".strip()\n",
148
+ "\n",
149
+ "\n",
150
+ "print(\"CUDA_VISIBLE_DEVICES=\", os.environ.get(\"CUDA_VISIBLE_DEVICES\"))\n",
151
+ "print(\"torch.__version__ =\", torch.__version__)\n",
152
+ "print(\"torch.version.cuda =\", torch.version.cuda)\n",
153
+ "print(\"cuda available:\", torch.cuda.is_available())\n",
154
+ "print(\"cuda device count:\", torch.cuda.device_count())\n",
155
+ "if torch.cuda.is_available():\n",
156
+ " print(\"current device:\", torch.cuda.current_device())\n",
157
+ " print(\"device name:\", torch.cuda.get_device_name(torch.cuda.current_device()))\n",
158
+ "\n",
159
+ "print(\"Using device:\", device)\n",
160
+ "\n",
161
+ "# --- Model Loading: prithivMLmods/DeepAttriCap-VLA-3B ---\n",
162
+ "MODEL_ID_N = \"prithivMLmods/DeepAttriCap-VLA-3B\"\n",
163
+ "processor = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)\n",
164
+ "model = Qwen2_5_VLForConditionalGeneration.from_pretrained(\n",
165
+ " MODEL_ID_N, trust_remote_code=True, dtype=torch.float16\n",
166
+ ").to(device).eval()\n",
167
+ "\n",
168
+ "\n",
169
+ "# --- PDF Generation and Preview Utility Function ---\n",
170
+ "def generate_and_preview_pdf(image: Image.Image, text_content: str, font_size: int, line_spacing: float, alignment: str, image_size: str):\n",
171
+ " \"\"\"\n",
172
+ " Generates a PDF, saves it, and then creates image previews of its pages.\n",
173
+ " Returns the path to the PDF and a list of paths to the preview images.\n",
174
+ " \"\"\"\n",
175
+ " if image is None or not text_content or not text_content.strip():\n",
176
+ " raise gr.Error(\"Cannot generate PDF. Image or text content is missing.\")\n",
177
+ "\n",
178
+ " # --- 1. Generate the PDF ---\n",
179
+ " temp_dir = tempfile.gettempdir()\n",
180
+ " pdf_filename = os.path.join(temp_dir, f\"output_{uuid.uuid4()}.pdf\")\n",
181
+ " doc = SimpleDocTemplate(\n",
182
+ " pdf_filename,\n",
183
+ " pagesize=A4,\n",
184
+ " rightMargin=inch, leftMargin=inch,\n",
185
+ " topMargin=inch, bottomMargin=inch\n",
186
+ " )\n",
187
+ " styles = getSampleStyleSheet()\n",
188
+ " style_normal = styles[\"Normal\"]\n",
189
+ " style_normal.fontSize = int(font_size)\n",
190
+ " style_normal.leading = int(font_size) * line_spacing\n",
191
+ " style_normal.alignment = {\"Left\": 0, \"Center\": 1, \"Right\": 2, \"Justified\": 4}[alignment]\n",
192
+ "\n",
193
+ " story = []\n",
194
+ "\n",
195
+ " img_buffer = BytesIO()\n",
196
+ " image.save(img_buffer, format='PNG')\n",
197
+ " img_buffer.seek(0)\n",
198
+ "\n",
199
+ " page_width, _ = A4\n",
200
+ " available_width = page_width - 2 * inch\n",
201
+ " image_widths = {\n",
202
+ " \"Small\": available_width * 0.3,\n",
203
+ " \"Medium\": available_width * 0.6,\n",
204
+ " \"Large\": available_width * 0.9,\n",
205
+ " }\n",
206
+ " img_width = image_widths[image_size]\n",
207
+ " img = RLImage(img_buffer, width=img_width, height=image.height * (img_width / image.width))\n",
208
+ " story.append(img)\n",
209
+ " story.append(Spacer(1, 12))\n",
210
+ "\n",
211
+ " cleaned_text = re.sub(r'#+\\s*', '', text_content).replace(\"*\", \"\")\n",
212
+ " text_paragraphs = cleaned_text.split('\\n')\n",
213
+ "\n",
214
+ " for para in text_paragraphs:\n",
215
+ " if para.strip():\n",
216
+ " story.append(Paragraph(para, style_normal))\n",
217
+ "\n",
218
+ " doc.build(story)\n",
219
+ "\n",
220
+ " # --- 2. Render PDF pages as images for preview ---\n",
221
+ " preview_images = []\n",
222
+ " try:\n",
223
+ " pdf_doc = fitz.open(pdf_filename)\n",
224
+ " for page_num in range(len(pdf_doc)):\n",
225
+ " page = pdf_doc.load_page(page_num)\n",
226
+ " pix = page.get_pixmap(dpi=150)\n",
227
+ " preview_img_path = os.path.join(temp_dir, f\"preview_{uuid.uuid4()}_p{page_num}.png\")\n",
228
+ " pix.save(preview_img_path)\n",
229
+ " preview_images.append(preview_img_path)\n",
230
+ " pdf_doc.close()\n",
231
+ " except Exception as e:\n",
232
+ " print(f\"Error generating PDF preview: {e}\")\n",
233
+ "\n",
234
+ " return pdf_filename, preview_images\n",
235
+ "\n",
236
+ "\n",
237
+ "# --- Core Application Logic ---\n",
238
+ "@spaces.GPU\n",
239
+ "def process_document_stream(\n",
240
+ " image: Image.Image,\n",
241
+ " prompt_input: str,\n",
242
+ " max_new_tokens: int,\n",
243
+ " temperature: float,\n",
244
+ " top_p: float,\n",
245
+ " top_k: int,\n",
246
+ " repetition_penalty: float\n",
247
+ "):\n",
248
+ " \"\"\"\n",
249
+ " Main generator function that handles model inference tasks with advanced generation parameters.\n",
250
+ " \"\"\"\n",
251
+ " if image is None:\n",
252
+ " yield \"Please upload an image.\", \"\"\n",
253
+ " return\n",
254
+ " if not prompt_input or not prompt_input.strip():\n",
255
+ " yield \"Please enter a prompt.\", \"\"\n",
256
+ " return\n",
257
+ "\n",
258
+ " # Integrate the system prompt\n",
259
+ " messages = [\n",
260
+ " {\"role\": \"system\", \"content\": CAPTION_SYSTEM_PROMPT},\n",
261
+ " {\"role\": \"user\", \"content\": [{\"type\": \"image\", \"image\": image}, {\"type\": \"text\", \"text\": prompt_input}]}\n",
262
+ " ]\n",
263
+ "\n",
264
+ " prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
265
+ " inputs = processor(text=[prompt_full], images=[image], return_tensors=\"pt\", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)\n",
266
+ "\n",
267
+ " streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)\n",
268
+ "\n",
269
+ " generation_kwargs = {\n",
270
+ " **inputs,\n",
271
+ " \"streamer\": streamer,\n",
272
+ " \"max_new_tokens\": max_new_tokens,\n",
273
+ " \"temperature\": temperature,\n",
274
+ " \"top_p\": top_p,\n",
275
+ " \"top_k\": top_k,\n",
276
+ " \"repetition_penalty\": repetition_penalty,\n",
277
+ " \"do_sample\": True\n",
278
+ " }\n",
279
+ "\n",
280
+ " thread = Thread(target=model.generate, kwargs=generation_kwargs)\n",
281
+ " thread.start()\n",
282
+ "\n",
283
+ " buffer = \"\"\n",
284
+ " for new_text in streamer:\n",
285
+ " buffer += new_text\n",
286
+ " buffer = buffer.replace(\"<|im_end|>\", \"\")\n",
287
+ " time.sleep(0.01)\n",
288
+ " yield buffer , buffer\n",
289
+ "\n",
290
+ " yield buffer, buffer\n",
291
+ "\n",
292
+ "\n",
293
+ "# --- Gradio UI Definition ---\n",
294
+ "def create_gradio_interface():\n",
295
+ " \"\"\"Builds and returns the Gradio web interface.\"\"\"\n",
296
+ " css = \"\"\"\n",
297
+ " .main-container { max-width: 1400px; margin: 0 auto; }\n",
298
+ " .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}\n",
299
+ " .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }\n",
300
+ " #gallery { min-height: 400px; }\n",
301
+ " \"\"\"\n",
302
+ " with gr.Blocks(theme=\"bethecloud/storj_theme\", css=css) as demo:\n",
303
+ " gr.HTML(\"\"\"\n",
304
+ " <div class=\"title\" style=\"text-align: center\">\n",
305
+ " <h1>DeepAttriCap-VLA-3B 👀</h1>\n",
306
+ " <p style=\"font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;\">\n",
307
+ " Using DeepAttriCap-VLA-3B for Image Captioning and Understanding\n",
308
+ " </p>\n",
309
+ " </div>\n",
310
+ " \"\"\")\n",
311
+ "\n",
312
+ " with gr.Row():\n",
313
+ " # Left Column (Inputs)\n",
314
+ " with gr.Column(scale=1):\n",
315
+ " gr.Textbox(\n",
316
+ " label=\"Selected Model\",\n",
317
+ " value=\"DeepAttriCap-VLA-3B\",\n",
318
+ " interactive=False\n",
319
+ " )\n",
320
+ " prompt_input = gr.Textbox(label=\"Query Input\", placeholder=\"✦︎ Enter your query\", value=\"Describe the image!\")\n",
321
+ " image_input = gr.Image(label=\"Upload Image\", type=\"pil\", sources=['upload'])\n",
322
+ "\n",
323
+ " with gr.Accordion(\"Advanced Settings\", open=False):\n",
324
+ " max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label=\"Max New Tokens\")\n",
325
+ " temperature = gr.Slider(label=\"Temperature\", minimum=0.1, maximum=4.0, step=0.1, value=0.6)\n",
326
+ " top_p = gr.Slider(label=\"Top-p (nucleus sampling)\", minimum=0.05, maximum=1.0, step=0.05, value=0.9)\n",
327
+ " top_k = gr.Slider(label=\"Top-k\", minimum=1, maximum=1000, step=1, value=50)\n",
328
+ " repetition_penalty = gr.Slider(label=\"Repetition penalty\", minimum=1.0, maximum=2.0, step=0.05, value=1.2)\n",
329
+ "\n",
330
+ " gr.Markdown(\"### PDF Export Settings\")\n",
331
+ " font_size = gr.Dropdown(choices=[\"8\", \"10\", \"12\", \"14\", \"16\", \"18\"], value=\"12\", label=\"Font Size\")\n",
332
+ " line_spacing = gr.Dropdown(choices=[1.0, 1.15, 1.5, 2.0], value=1.15, label=\"Line Spacing\")\n",
333
+ " alignment = gr.Dropdown(choices=[\"Left\", \"Center\", \"Right\", \"Justified\"], value=\"Justified\", label=\"Text Alignment\")\n",
334
+ " image_size = gr.Dropdown(choices=[\"Small\", \"Medium\", \"Large\"], value=\"Medium\", label=\"Image Size in PDF\")\n",
335
+ "\n",
336
+ " process_btn = gr.Button(\"🚀 Process Image\", variant=\"primary\", elem_classes=[\"process-button\"], size=\"lg\")\n",
337
+ " clear_btn = gr.Button(\"🗑️ Clear All\", variant=\"secondary\")\n",
338
+ "\n",
339
+ " # Right Column (Outputs)\n",
340
+ " with gr.Column(scale=2):\n",
341
+ " with gr.Tabs() as tabs:\n",
342
+ " with gr.Tab(\"📝 Extracted Content\"):\n",
343
+ " raw_output_stream = gr.Textbox(label=\"Raw Model Output Stream\", interactive=False, lines=15, show_copy_button=True)\n",
344
+ " with gr.Tab(\"📰 README.md\"):\n",
345
+ " with gr.Accordion(\"(Result.md)\", open=True):\n",
346
+ " markdown_output = gr.Markdown()\n",
347
+ "\n",
348
+ " with gr.Tab(\"📋 PDF Preview\"):\n",
349
+ " generate_pdf_btn = gr.Button(\"📄 Generate PDF & Render\", variant=\"primary\")\n",
350
+ " pdf_output_file = gr.File(label=\"Download Generated PDF\", interactive=False)\n",
351
+ " pdf_preview_gallery = gr.Gallery(label=\"PDF Page Preview\", show_label=True, elem_id=\"gallery\", columns=2, object_fit=\"contain\", height=\"auto\")\n",
352
+ "\n",
353
+ " # Event Handlers\n",
354
+ " def clear_all_outputs():\n",
355
+ " return None, \"\", \"Raw output will appear here.\", \"\", None, None\n",
356
+ "\n",
357
+ " process_btn.click(\n",
358
+ " fn=process_document_stream,\n",
359
+ " inputs=[image_input, prompt_input, max_new_tokens, temperature, top_p, top_k, repetition_penalty],\n",
360
+ " outputs=[raw_output_stream, markdown_output]\n",
361
+ " )\n",
362
+ "\n",
363
+ " generate_pdf_btn.click(\n",
364
+ " fn=generate_and_preview_pdf,\n",
365
+ " inputs=[image_input, raw_output_stream, font_size, line_spacing, alignment, image_size],\n",
366
+ " outputs=[pdf_output_file, pdf_preview_gallery]\n",
367
+ " )\n",
368
+ "\n",
369
+ " clear_btn.click(\n",
370
+ " clear_all_outputs,\n",
371
+ " outputs=[image_input, prompt_input, raw_output_stream, markdown_output, pdf_output_file, pdf_preview_gallery]\n",
372
+ " )\n",
373
+ " return demo\n",
374
+ "\n",
375
+ "if __name__ == \"__main__\":\n",
376
+ " demo = create_gradio_interface()\n",
377
+ " demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)"
378
+ ]
379
+ }
380
+ ],
381
+ "metadata": {
382
+ "accelerator": "GPU",
383
+ "colab": {
384
+ "gpuType": "T4",
385
+ "provenance": []
386
+ },
387
+ "kernelspec": {
388
+ "display_name": "Python 3",
389
+ "name": "python3"
390
+ },
391
+ "language_info": {
392
+ "name": "python"
393
+ }
394
+ },
395
+ "nbformat": 4,
396
+ "nbformat_minor": 0
397
+ }