Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from pathlib import Path | |
| from PIL import Image | |
| import PyPDF2 | |
| import docx | |
| from sentence_transformers import SentenceTransformer, util | |
| import faiss | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BlipProcessor, BlipForConditionalGeneration | |
| import torch | |
| from datetime import datetime | |
| import fitz # PyMuPDF | |
| import shutil | |
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Half precision is only dependable on CUDA: on CPU-only hosts many float16
# kernels are missing or extremely slow, so fall back to float32 there.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print("Loading models...")
# Sentence embedding model used for both text-chunk retrieval and
# query/caption similarity.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

print("Loading LLM...")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
llm_model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=MODEL_DTYPE,
    device_map="auto"
)

print("Loading image caption model...")
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
caption_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large",
    torch_dtype=MODEL_DTYPE
).to(DEVICE)
print("β All models loaded!")

# ---------------------------------------------------------------------------
# Storage (module-level state shared by the Gradio callbacks)
# ---------------------------------------------------------------------------
documents = []           # dicts with 'text', 'source' and, for PDFs, 'page'
images = []              # dicts with 'path', 'source', 'page'
image_captions = []      # caption strings, index-aligned with `images`
embeddings_index = None  # FAISS index over `documents`, or None when empty
def generate_image_caption(image_path):
    """Generate a caption for the image stored at *image_path*.

    Args:
        image_path: Path of an image file readable by PIL.

    Returns:
        The stripped caption string, or "" when captioning fails for any
        reason (the error is printed), so callers can simply skip the image.
    """
    try:
        # convert() yields an independent in-memory copy, so the file handle
        # can be released as soon as the conversion is done.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')

        # The processor emits float32 pixel values. Cast them to the model's
        # own dtype as well as its device; otherwise a half-precision model
        # raises a dtype-mismatch error, which the handler below would turn
        # into an empty caption for every image.
        inputs = caption_processor(img, return_tensors="pt").to(
            caption_model.device, caption_model.dtype
        )

        # Beam search without sampling is deterministic; `temperature` is only
        # consulted when do_sample=True, so it is intentionally not passed.
        output = caption_model.generate(
            **inputs,
            max_length=100,
            num_beams=5,
        )
        caption = caption_processor.decode(output[0], skip_special_tokens=True)
        return caption.strip()
    except Exception as e:
        # Best effort: one bad image must not abort the whole upload.
        print(f"Caption error: {e}")
        return ""
def extract_images_from_pdf(pdf_path):
    """Extract the embedded images of a PDF to files under /tmp.

    Only images of at least 150x150 pixels are reported, which filters out
    icons and small logos.

    Args:
        pdf_path: Path of the PDF file.

    Returns:
        A list of dicts with keys 'path' (extracted image file), 'page'
        (1-based page number) and 'source' (PDF file name). The list is empty
        when the PDF cannot be opened; images that fail individually are
        skipped.
    """
    extracted = []
    pdf_file = Path(pdf_path)

    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        print(f"PDF image extraction error: {e}")
        return extracted

    try:
        for page_number, page in enumerate(doc, start=1):
            for img_index, img in enumerate(page.get_images(full=True)):
                try:
                    xref = img[0]
                    base_image = doc.extract_image(xref)

                    # Include the PDF's stem so images from different PDFs in
                    # the same batch do not overwrite each other, and keep the
                    # real format extension instead of always claiming PNG.
                    ext = base_image.get("ext") or "png"
                    img_path = (
                        f"/tmp/{pdf_file.stem}_page{page_number}_img{img_index}.{ext}"
                    )
                    with open(img_path, "wb") as f:
                        f.write(base_image["image"])

                    # Re-open with PIL to confirm the bytes are a readable
                    # image and to read its dimensions.
                    with Image.open(img_path) as test_img:
                        width, height = test_img.size

                    # Only keep meaningful images (not tiny icons/logos)
                    if width >= 150 and height >= 150:
                        extracted.append({
                            'path': img_path,
                            'page': page_number,
                            'source': pdf_file.name,
                        })
                except Exception as e:
                    # Best effort per image, but leave a trace for debugging.
                    print(
                        f"Skipping image {img_index} on page {page_number} "
                        f"of {pdf_file.name}: {e}"
                    )
                    continue
    except Exception as e:
        print(f"PDF image extraction error: {e}")
    finally:
        # Always release the document, even when iteration fails midway.
        doc.close()

    return extracted
def extract_pdf_text(pdf_path):
    """Read a PDF page by page and return the pages that contain text.

    Returns:
        A list of dicts, one per page with non-blank text, holding 'text'
        (the extracted page text), 'page' (1-based page number) and 'source'
        (the PDF's file name).
    """
    page_records = []
    with open(pdf_path, 'rb') as handle:
        reader = PyPDF2.PdfReader(handle)
        for page_number, pdf_page in enumerate(reader.pages, start=1):
            page_text = pdf_page.extract_text()
            # Pages without any visible text (e.g. pure scans) are skipped.
            if not page_text.strip():
                continue
            page_records.append({
                'text': page_text,
                'page': page_number,
                'source': Path(pdf_path).name,
            })
    return page_records
def extract_docx_text(docx_path):
    """Return the text of a .docx file as a single-record list.

    Non-blank paragraphs are joined with newlines. The one record has the
    keys 'text' and 'source' (the file name).
    """
    word_doc = docx.Document(docx_path)
    kept_paragraphs = []
    for paragraph in word_doc.paragraphs:
        if paragraph.text.strip():
            kept_paragraphs.append(paragraph.text)
    record = {
        'text': '\n'.join(kept_paragraphs),
        'source': Path(docx_path).name,
    }
    return [record]
def extract_txt_text(txt_path):
    """Return the contents of a UTF-8 text file as a single-record list.

    The one record has the keys 'text' (entire file contents) and 'source'
    (the file name).
    """
    with open(txt_path, 'r', encoding='utf-8') as handle:
        contents = handle.read()
        record = {'text': contents, 'source': Path(txt_path).name}
    return [record]
def chunk_text(text, size=400):
    """Split *text* into pieces of at most *size* whitespace-separated words.

    Words inside each piece are re-joined with single spaces, so the original
    whitespace is normalised. Empty or whitespace-only text yields [].
    """
    tokens = text.split()
    window_starts = range(0, len(tokens), size)
    # Every window holds at least one non-empty word, so no piece is blank.
    return [' '.join(tokens[begin:begin + size]) for begin in window_starts]
def _upload_path(file):
    """Return the filesystem path of an upload (plain path or object with .name)."""
    return file if isinstance(file, str) else file.name


def _ingest_file(path, docs, imgs, captions):
    """Append the text chunks and captioned images of one file to the given lists."""
    ext = Path(path).suffix.lower()

    if ext == '.pdf':
        for page in extract_pdf_text(path):
            for small_chunk in chunk_text(page['text']):
                docs.append({
                    'text': small_chunk,
                    'source': page['source'],
                    'page': page['page'],
                })
        for img in extract_images_from_pdf(path):
            caption = generate_image_caption(img['path'])
            if caption:  # Only keep images that could be captioned
                imgs.append(img)
                captions.append(caption)

    elif ext in ('.docx', '.txt'):
        extractor = extract_docx_text if ext == '.docx' else extract_txt_text
        for record in extractor(path):
            for small_chunk in chunk_text(record['text']):
                docs.append({
                    'text': small_chunk,
                    'source': record['source'],
                })

    elif ext in ('.jpg', '.jpeg', '.png', '.gif', '.bmp'):
        caption = generate_image_caption(path)
        if caption:
            imgs.append({
                'path': path,
                'source': Path(path).name,
                'page': '',
            })
            captions.append(caption)
    # Any other extension is silently ignored.


def process_files(files, progress=gr.Progress()):
    """Process uploaded files and rebuild the searchable state.

    Text is extracted from PDF/DOCX/TXT files and split into chunks; images
    (standalone or embedded in PDFs) are captioned. The module-level
    `documents`, `images`, `image_captions` and `embeddings_index` are then
    replaced together, so they always describe the same upload batch.

    Args:
        files: Uploaded files, either plain path strings or objects exposing
            the path as `.name`.
        progress: Gradio progress tracker.

    Returns:
        A Markdown status string summarising what was processed.
    """
    global documents, images, image_captions, embeddings_index

    if not files:
        return "β οΈ Please upload files first"

    # Build into locals first; the globals are swapped in one step at the end
    # so a concurrent query never sees documents without a matching index.
    new_documents = []
    new_images = []
    new_captions = []
    failed = []

    total = len(files)
    for idx, file in enumerate(files):
        path = _upload_path(file)
        name = Path(path).name
        progress((idx + 1) / total, desc=f"Processing {name}...")
        try:
            _ingest_file(path, new_documents, new_images, new_captions)
        except Exception as e:
            # One unreadable file must not discard the rest of the batch.
            print(f"Failed to process {name}: {e}")
            failed.append(name)

    # Create embeddings
    progress(0.9, desc="Creating embeddings...")
    new_index = None  # stays None when no text was extracted (no stale index)
    if new_documents:
        texts = [doc['text'] for doc in new_documents]
        embeddings = embedding_model.encode(texts, show_progress_bar=False)
        new_index = faiss.IndexFlatL2(embeddings.shape[1])
        new_index.add(embeddings.astype('float32'))

    documents = new_documents
    images = new_images
    image_captions = new_captions
    embeddings_index = new_index

    progress(1.0, desc="Done!")

    status = f"β **Processing Complete!**\n\n"
    status += f"π **Text chunks:** {len(documents)}\n"
    status += f"πΌοΈ **Images extracted:** {len(images)}\n"
    if images:
        status += f"\n**Sample captions:**\n"
        for i, (img, cap) in enumerate(zip(images[:3], image_captions[:3]), 1):
            status += f"{i}. {img['source']}"
            if img.get('page'):
                status += f" (Page {img['page']})"
            status += f":\n _{cap}_\n"
    if failed:
        status += f"\n**Skipped (could not be read):** {', '.join(failed)}\n"
    return status
def search_documents(query, k=3):
    """Return up to *k* indexed text chunks most similar to *query*.

    Args:
        query: The user's question.
        k: Maximum number of chunks to return.

    Returns:
        A list of document dicts taken from the module-level `documents`,
        nearest first. Empty when nothing has been indexed yet.
    """
    if not documents or embeddings_index is None:
        return []

    # Never request more neighbours than the index holds.
    k = min(k, embeddings_index.ntotal)
    if k <= 0:
        return []

    query_vec = embedding_model.encode([query])
    _, indices = embeddings_index.search(query_vec.astype('float32'), k)

    # FAISS pads missing results with -1. A negative value would pass a plain
    # `idx < len(documents)` check and index from the END of the list, so the
    # lower bound has to be checked explicitly.
    return [documents[idx] for idx in indices[0] if 0 <= idx < len(documents)]
def find_relevant_images(query, relevance_threshold=0.25, top_k=3):
    """Find stored images whose caption is semantically close to *query*.

    Args:
        query: The user's question.
        relevance_threshold: Minimum cosine similarity (exclusive) between
            the query and a caption for the image to count as relevant.
        top_k: Maximum number of images to return.

    Returns:
        A pair `(paths, explanations)` of equal-length lists ordered from
        most to least relevant: the image file paths and a Markdown
        explanation for each. Both are empty when nothing is relevant.
    """
    if not images or not image_captions:
        return [], []

    # Encode query and captions
    query_emb = embedding_model.encode(query, convert_to_tensor=True)
    caption_embs = embedding_model.encode(image_captions, convert_to_tensor=True)

    # One cosine similarity per caption, as plain Python floats.
    similarities = util.cos_sim(query_emb, caption_embs)[0].tolist()

    # Keep each score bundled with ITS OWN image. Sorting a filtered list
    # against the unfiltered score vector would pair images with the wrong
    # scores as soon as any image falls below the threshold.
    scored = [
        (score, img_info, caption)
        for score, img_info, caption in zip(similarities, images, image_captions)
        if score > relevance_threshold
    ]
    scored.sort(key=lambda item: item[0], reverse=True)

    relevant_imgs = []
    explanations = []
    for score, img_info, caption in scored[:top_k]:
        relevant_imgs.append(img_info['path'])

        # Create explanation
        exp = f"**π Source:** {img_info['source']}"
        if img_info.get('page'):
            exp += f" (Page {img_info['page']})"
        exp += f"\n**π¬ Description:** {caption}"
        exp += f"\n**π― Relevance:** {score * 100:.1f}%\n"
        explanations.append(exp)

    return relevant_imgs, explanations
def generate_answer(question, context_docs):
    """Generate an answer to *question* grounded in the retrieved chunks.

    Args:
        question: The user's question.
        context_docs: Document dicts whose 'text' forms the context.

    Returns:
        The generated answer text, stripped of surrounding whitespace.
    """
    max_prompt_tokens = 1200
    # Slack for tokenisation drift when the trimmed context is re-encoded.
    safety_margin = 16

    def build_prompt(ctx):
        # NOTE(review): blank lines inside the original prompt literal may
        # have been lost; the layout below matches what is visible.
        return (
            "Answer this question based only on the context provided. "
            "Be concise and accurate.\n"
            "Context:\n"
            f"{ctx}\n"
            f"Question: {question}\n"
            "Answer:"
        )

    context = '\n\n'.join(doc['text'] for doc in context_docs)

    # Trim the CONTEXT rather than the whole prompt. Truncating the full
    # prompt on the right would drop the question and the trailing "Answer:"
    # cue whenever the retrieved chunks are long.
    overhead = len(tokenizer(build_prompt("")).input_ids)
    context_budget = max(max_prompt_tokens - overhead - safety_margin, 0)
    context_ids = tokenizer(context, add_special_tokens=False).input_ids
    if len(context_ids) > context_budget:
        context = tokenizer.decode(
            context_ids[:context_budget], skip_special_tokens=True
        )

    prompt = build_prompt(context)
    # truncation stays on as a last-resort guard; the tensors must live on
    # the same device as the model.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_prompt_tokens,
    ).to(llm_model.device)

    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,  # passes the attention mask along with the input ids
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "Answer:" breaks when the context itself contains that string.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return answer.strip()
def answer_query(question, progress=gr.Progress()):
    """Answer *question* from the processed documents.

    Returns:
        A pair `(markdown, gallery_items)`: the Markdown response (answer,
        text sources and image notes) and the list of relevant image paths,
        or None when there are none. On invalid input or missing data the
        Markdown is a warning message and the second element is None.
    """
    # Guard clauses: nothing to do without a question or processed documents.
    if not question.strip():
        return "β οΈ Please enter a question", None
    if not documents:
        return "β οΈ Please upload and process documents first", None

    progress(0.3, desc="Searching documents...")
    hits = search_documents(question, k=3)
    if not hits:
        return "β No relevant information found", None

    progress(0.6, desc="Generating answer...")
    generated = generate_answer(question, hits)

    # Assemble the Markdown piece by piece and join once at the end.
    parts = [
        f"## π‘ Answer\n\n{generated}\n\n",
        f"## π Text Sources\n\n",
    ]
    for rank, hit in enumerate(hits, 1):
        source_name = hit['source']
        page_no = hit.get('page', '')
        page_suffix = f" (Page {page_no})" if page_no else ""
        parts.append(f"{rank}. **{source_name}**{page_suffix}\n")

    progress(0.9, desc="Finding relevant images...")
    gallery_paths, image_notes = find_relevant_images(
        question, relevance_threshold=0.25
    )

    if gallery_paths and image_notes:
        parts.append(f"\n## πΌοΈ Related Images\n\n")
        parts.extend(f"{note}\n" for note in image_notes)
    else:
        parts.append(f"\n_No relevant images found for this query_\n")

    progress(1.0, desc="Done!")
    return "".join(parts), (gallery_paths or None)
# UI
# NOTE(review): the original indentation was lost, so the nesting of the
# components inside the Row/Column containers is reconstructed. Confirm the
# placement of `answer`, `gallery` and the examples against the intended layout.
with gr.Blocks(
    title="DocVision AI",
    theme=gr.themes.Soft(primary_hue="indigo")
) as app:
    gr.Markdown("""
# π DocVision AI - Intelligent Document Q&A
### Upload documents and get AI-powered answers with relevant images
""")

    with gr.Row():
        # Left column: upload and processing.
        with gr.Column():
            file_input = gr.File(
                label="π Upload Documents & Images",
                file_count="multiple",
                # Kept in sync with the extensions process_files() handles;
                # .jpeg/.gif/.bmp were accepted there but could not be picked.
                file_types=[
                    ".pdf", ".docx", ".txt",
                    ".jpg", ".jpeg", ".png", ".gif", ".bmp",
                ]
            )
            process_btn = gr.Button(
                "β‘ Process Documents",
                variant="primary",
                size="lg"
            )
            status = gr.Markdown(label="π Processing Status")

        # Right column: question, answer and related images.
        with gr.Column():
            question = gr.Textbox(
                label="β Ask Your Question",
                placeholder="What would you like to know about your documents?",
                lines=3
            )
            ask_btn = gr.Button(
                "π Get Answer",
                variant="primary",
                size="lg"
            )
            answer = gr.Markdown(label="π Answer with Sources")
            gallery = gr.Gallery(
                label="πΌοΈ Relevant Images (Only shown if related to your question)",
                columns=2,
                height=500,
                show_label=True
            )

    gr.Markdown("### π‘ Example Questions")
    gr.Examples(
        examples=[
            ["What is the main topic of this document?"],
            ["Explain the workflow or architecture shown"],
            ["What are the key findings?"],
            ["Describe any diagrams or charts present"]
        ],
        inputs=question
    )

    # Event handlers (must be registered inside the Blocks context)
    process_btn.click(
        process_files,
        inputs=[file_input],
        outputs=[status]
    )
    # The button and pressing Enter in the textbox trigger the same handler.
    ask_btn.click(
        answer_query,
        inputs=[question],
        outputs=[answer, gallery]
    )
    question.submit(
        answer_query,
        inputs=[question],
        outputs=[answer, gallery]
    )

if __name__ == "__main__":
    app.launch()