ignaciaginting committed on
Commit
4c8d2d0
·
verified ·
1 Parent(s): 6387027

check the colpali first

Browse files
Files changed (1) hide show
  1. app.py +58 -67
app.py CHANGED
@@ -1,95 +1,86 @@
1
  import streamlit as st
2
- import fitz # PyMuPDF
3
  import torch
4
  from PIL import Image
5
- import io
6
- from sentence_transformers import SentenceTransformer, util
7
  from transformers.utils.import_utils import is_flash_attn_2_available
8
  from colpali_engine.models import ColQwen2, ColQwen2Processor
9
 
10
  # -----------------------------
11
- # Load models
12
  # -----------------------------
13
  @st.cache_resource
14
- def load_models():
15
- text_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
16
- colpali_model = ColQwen2.from_pretrained(
17
- "vidore/colqwen2-v1.0",
18
  torch_dtype=torch.bfloat16,
19
  device_map="cuda:0" if torch.cuda.is_available() else "cpu",
20
- attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None
21
  ).eval()
22
- colpali_processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
23
- return text_model, colpali_model, colpali_processor
24
 
25
- text_model, colpali_model, colpali_processor = load_models()
26
 
27
- # -----------------------------
28
- # UI Elements
29
- # -----------------------------
30
- st.title("πŸ“„ Chat with Your Financial Report (PDF + Table + Image)")
31
- pdf_file = st.file_uploader("Upload your PDF", type="pdf")
32
- use_colpali = st.checkbox("Enable ColPali (for image tables)", value=True)
33
 
34
  # -----------------------------
35
- # Process PDF
36
  # -----------------------------
37
- if pdf_file:
38
- doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
39
- text_chunks = []
40
  images = []
41
-
42
  for page in doc:
43
- blocks = page.get_text("blocks")
44
- for block in blocks:
45
- if block[4].strip():
46
- text_chunks.append(block[4].strip())
47
-
48
- # Extract images if ColPali is enabled
49
- if use_colpali:
50
- for img_index, img in enumerate(page.get_images(full=True)):
51
- xref = img[0]
52
- base_image = doc.extract_image(xref)
53
- image_bytes = base_image["image"]
54
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
55
- images.append(image)
56
 
57
- # Embed all text chunks
58
- text_embeddings = text_model.encode(text_chunks, convert_to_tensor=True)
 
 
 
 
 
59
 
60
- if use_colpali and images:
61
- image_inputs = colpali_processor.process_images(images).to(colpali_model.device)
62
- with torch.no_grad():
63
- image_embeddings = colpali_model(**image_inputs)
64
- else:
65
- image_embeddings = None
 
 
 
 
66
 
67
- # -----------------------------
68
- # Chat Interface
69
- # -----------------------------
70
- user_query = st.text_input("Ask a question about your PDF:")
71
 
72
- if user_query:
73
- st.write("πŸ” Searching for answers...")
 
 
 
 
74
 
75
- # Text-based search
76
- query_embedding = text_model.encode(user_query, convert_to_tensor=True)
77
- top_text_hits = util.semantic_search(query_embedding, text_embeddings, top_k=3)[0]
78
 
79
- st.markdown("### πŸ“ Top Text Answers")
80
- for hit in top_text_hits:
81
- score = hit['score']
82
- chunk = text_chunks[hit['corpus_id']]
83
- st.markdown(f"**Score:** {score:.4f}\n\n{chunk}")
84
 
85
- # Image-based search (ColPali)
86
- if use_colpali and image_embeddings is not None:
87
- query_vec = colpali_processor.process_queries([user_query]).to(colpali_model.device)
88
  with torch.no_grad():
89
- query_embedding_img = colpali_model(**query_vec)
 
 
 
 
 
 
90
 
91
- scores = colpali_processor.score_multi_vector(query_embedding_img, image_embeddings)
92
- top_k = torch.topk(scores, k=min(3, len(images)))
93
- st.markdown("### πŸ–ΌοΈ Top Image/Table Matches")
94
- for idx, score in zip(top_k.indices, top_k.values):
95
- st.image(images[idx], caption=f"Similarity Score: {score.item():.4f}", use_column_width=True)
 
1
  import streamlit as st
 
2
  import torch
3
  from PIL import Image
4
+ import fitz # PyMuPDF
 
5
  from transformers.utils.import_utils import is_flash_attn_2_available
6
  from colpali_engine.models import ColQwen2, ColQwen2Processor
7
 
8
  # -----------------------------
9
+ # Load ColPali Model
10
  # -----------------------------
11
@st.cache_resource
def load_colpali():
    """Load the ColQwen2 retrieval model and its processor.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the
    already-loaded weights instead of downloading/initializing again.

    Returns:
        tuple: ``(model, processor)`` — the model in eval mode, placed on
        GPU when one is available, otherwise on CPU.
    """
    checkpoint = "vidore/colqwen2-v1.0"
    # Pick the best available execution path up front.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    attn = "flash_attention_2" if is_flash_attn_2_available() else None
    model = ColQwen2.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map=device,
        attn_implementation=attn,
    ).eval()
    processor = ColQwen2Processor.from_pretrained(checkpoint)
    return model, processor
22
 
23
# -----------------------------
# One-time setup: model handle and page chrome
# -----------------------------
colpali_model, colpali_processor = load_colpali()

st.title("πŸ” Visual PDF Search with ColPali")

# Uploaded file (or None until the user picks one); gates the whole app below.
pdf_file = st.file_uploader("Upload a PDF", type="pdf")
 
 
 
 
27
 
28
  # -----------------------------
29
+ # Convert PDF to image
30
  # -----------------------------
31
# -----------------------------
# Convert PDF to image
# -----------------------------
def render_pdf_page_as_image(doc, zoom=2.0):
    """Rasterize every page of an open PyMuPDF document to a PIL RGB image.

    Args:
        doc: an opened ``fitz.Document``.
        zoom: scale factor applied on both axes (2.0 roughly doubles the
            default 72 dpi); higher values give sharper crops at the cost
            of memory.

    Returns:
        list: one ``PIL.Image.Image`` per page; empty for an empty document.
    """
    # The zoom transform does not depend on the page, so build it once
    # instead of re-creating it on every iteration (was inside the loop).
    mat = fitz.Matrix(zoom, zoom)
    images = []
    for page in doc:
        pix = page.get_pixmap(matrix=mat)
        # pix.samples is a flat width*height*3 RGB byte buffer.
        images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
    return images
 
 
 
 
 
 
 
 
39
 
40
+ # -----------------------------
41
+ # Chunk image into pieces
42
+ # -----------------------------
43
# -----------------------------
# Chunk image into pieces
# -----------------------------
def chunk_image(image, rows=2, cols=2):
    """Split an image into a rows x cols grid of 512x512 tiles.

    Each grid cell is cropped and resized to 512x512 (aspect ratio is not
    preserved). The last row/column extends to the image edge, so the
    remainder pixels from integer division are no longer silently dropped
    when the dimensions are not divisible by ``rows``/``cols``.

    Args:
        image: a PIL image (anything exposing ``.size``, ``.crop``, ``.resize``).
        rows: number of horizontal bands to cut.
        cols: number of vertical bands to cut.

    Returns:
        list: ``rows * cols`` resized crops in row-major order.
    """
    width, height = image.size
    chunk_width = width // cols
    chunk_height = height // rows

    chunks = []
    for r in range(rows):
        top = r * chunk_height
        # Last band absorbs the division remainder so the full image is covered.
        bottom = height if r == rows - 1 else top + chunk_height
        for c in range(cols):
            left = c * chunk_width
            right = width if c == cols - 1 else left + chunk_width
            chunks.append(image.crop((left, top, right, bottom)).resize((512, 512)))
    return chunks
58
 
59
if pdf_file:
    # Open the uploaded PDF from its in-memory bytes. Rendering is eager,
    # so the document handle can be closed immediately afterwards via the
    # context manager (fix: the original never closed the fitz.Document).
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
        images = render_pdf_page_as_image(doc)

    if not images:
        st.warning("Failed to read content from the PDF.")
    else:
        # Tile every page into a 2x2 grid so retrieval can point at a
        # sub-region (e.g. one table) rather than a whole page.
        all_chunks = []
        for image in images:
            all_chunks.extend(chunk_image(image, rows=2, cols=2))

        user_query = st.text_input("What are you looking for in the document?")

        if user_query:
            # Preprocess the tiles and the query and move them to the
            # model's device.
            batch_images = colpali_processor.process_images(all_chunks).to(colpali_model.device)
            batch_queries = colpali_processor.process_queries([user_query]).to(colpali_model.device)

            # Inference only — no gradients needed.
            with torch.no_grad():
                image_embeddings = colpali_model(**batch_images)
                query_embeddings = colpali_model(**batch_queries)

            # NOTE(review): scores appears to be (n_queries, n_chunks); with a
            # single query the flattened argmax equals the chunk index — confirm
            # against colpali_engine's score_multi_vector contract.
            scores = colpali_processor.score_multi_vector(query_embeddings, image_embeddings)
            best_idx = torch.argmax(scores).item()
            best_image = all_chunks[best_idx]
            best_score = scores[0, best_idx].item()

            st.markdown("### πŸ” Most Relevant Image Chunk")
            st.image(best_image, caption=f"Score: {best_score:.4f}", use_column_width=True)