# Source: Hugging Face Space by pragnesh002
# Commit message: update requirements
# Commit: 4cef604
# ============================================================================
# FILE: app.py (Hugging Face Transformers + GGUF caching + Vision)
# ============================================================================
import gradio as gr
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import fitz
import json
import pandas as pd
import os
import re
import xlsxwriter
import tempfile
from pathlib import Path
import gc
import time
import requests
# ============================================================================
# SYSTEM PROMPT
# ============================================================================
# Sent as the "system" chat message by extract_products_from_text(); it tells
# the model which JSON keys to emit for each product it finds on a page.
SYSTEM_PROMPT = """You are a data extraction assistant.
Extract item details from the provided text.
Provide output as a JSON array of objects with keys: 'Flag', 'Product Code', 'Description', 'Manufacturer', 'Supplier', 'Material', 'Dimensions', 'Product Image'.
If a key's value is not found, provide empty string "".
If no items found, return empty array [].
Include only unique Product Code values.
For Dimensions, format as "Height: X; Width: Y; Depth: Z" (semicolon-separated).
Do not add duplicate or test data."""
# ============================================================================
# GLOBAL MODELS
# ============================================================================
# Module-level model singletons. Each stays None until the matching loader
# (load_vision_model / load_text_model) fills it in on first use.
vision_model = None  # set by load_vision_model()
vision_processor = None  # set by load_vision_model()
text_model = None  # set by load_text_model()
text_tokenizer = None  # set by load_text_model()
# ============================================================================
# GGUF CACHING
# ============================================================================
def get_gguf_local(model_id, filename="unsloth.Q4_K_M.gguf", cache_dir="/tmp/gguf_cache"):
    """Return a local path to a GGUF file, downloading it on first use.

    The file is streamed to a ``.part`` file next to its final location and
    renamed only after the download completed, so an interrupted download is
    never mistaken for a valid cache entry.

    Args:
        model_id: Hugging Face repository id, e.g. ``"user/repo"``.
        filename: Name of the GGUF file inside the repository.
        cache_dir: Directory used as the on-disk cache; created if missing.

    Returns:
        The path of the cached file, as a string.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    local_file = cache_dir / filename

    if local_file.exists():
        print("✅ Using cached GGUF file.")
        return str(local_file)

    url = f"https://huggingface.co/{model_id}/resolve/main/{filename}"
    print(f"📥 Downloading GGUF file from Hugging Face: {url}")

    partial_file = local_file.with_name(local_file.name + ".part")
    try:
        # stream=True keeps memory flat: GGUF files are typically several GB
        # and must not be buffered in RAM as a single bytes object.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_file, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        # Atomic on the same filesystem: readers see either nothing or the
        # complete file, never a truncated one.
        partial_file.replace(local_file)
    except BaseException:
        # Remove the incomplete download so the next call starts clean.
        if partial_file.exists():
            partial_file.unlink()
        raise

    print("✅ GGUF file downloaded and cached locally.")
    return str(local_file)
# ============================================================================
# MODEL LOADERS
# ============================================================================
def load_vision_model():
    """Load the LLaVA-NeXT vision model and its processor on first use.

    The model is loaded with a 4-bit NF4 bitsandbytes quantization config and
    float16 compute. Results are stored in the module-level ``vision_model``
    and ``vision_processor`` globals, so later calls return immediately.

    Returns:
        A ``(vision_model, vision_processor)`` tuple.
    """
    global vision_model, vision_processor
    if vision_model is None:
        model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
        print("📸 Loading vision model...")
        processor = LlavaNextProcessor.from_pretrained(model_name)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        model = LlavaNextForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        # Publish both globals together, only after both loads succeeded, so a
        # failed model load never leaves a half-initialised pair behind.
        vision_processor, vision_model = processor, model
        print("✅ Vision model loaded!")
    return vision_model, vision_processor
def load_text_model():
    """Load the GGUF text-extraction model and its tokenizer on first use.

    The GGUF file is fetched (or taken from the local cache) through
    ``get_gguf_local`` and then handed to Transformers via the ``gguf_file``
    argument. Results are stored in the module-level ``text_model`` and
    ``text_tokenizer`` globals, so later calls return immediately.

    Returns:
        A ``(text_model, text_tokenizer)`` tuple.
    """
    global text_model, text_tokenizer
    if text_model is None:
        model_id = "pragnesh002/Qwen3-4B-Product-Extractor-GGUF-Q4-K-M"
        gguf_path = Path(get_gguf_local(model_id))
        print("📝 Loading GGUF text extraction model...")
        # NOTE(review): the tokenizer is still resolved from the hub repo; if
        # that repo ships only the .gguf file, pass gguf_file here as well.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        # from_pretrained() expects a directory (or repo id) as its first
        # argument; the GGUF file inside it is selected with gguf_file=.
        # Passing the .gguf path itself as the model path does not work.
        model = AutoModelForCausalLM.from_pretrained(
            str(gguf_path.parent),
            gguf_file=gguf_path.name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        # Publish both globals together, only after both loads succeeded.
        text_tokenizer, text_model = tokenizer, model
        print("✅ GGUF text model loaded via Transformers!")
    return text_model, text_tokenizer
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def cleanup_memory():
    """Run the garbage collector and, when CUDA is present, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def extract_pdf_text(pdf_path):
    """Extract the text of every page of a PDF.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        A list with one string per page. Pages with fewer than 50 characters
        of text are replaced by a ``"[Page N - Low text content]"`` marker.
        If the PDF cannot be read, a single-element list holding an
        ``"Error extracting text: ..."`` message is returned instead.
    """
    try:
        doc = fitz.open(pdf_path)
        try:
            pages_text = []
            for page_num in range(doc.page_count):
                text = doc.load_page(page_num).get_text().strip()
                if len(text) < 50:
                    text = f"[Page {page_num + 1} - Low text content]"
                pages_text.append(text)
            return pages_text
        finally:
            # Close the document even when a page fails half-way through.
            doc.close()
    except Exception as e:
        # Best-effort by design: callers get the failure as page text.
        return [f"Error extracting text: {str(e)}"]
def _parse_product_json(output_text):
    """Turn raw model output into a list of product dicts (raises ValueError on invalid JSON)."""
    # Drop any <think>...</think> reasoning block the model may emit, so
    # brackets inside it cannot be mistaken for the JSON payload.
    cleaned = re.sub(r'<think>.*?</think>', '', output_text, flags=re.DOTALL)
    fenced = re.search(r'```json\s*(.*?)\s*```', cleaned, re.DOTALL)
    if fenced:
        json_str = fenced.group(1)
    else:
        bare = re.search(r'(\[.*\]|\{.*\})', cleaned, re.DOTALL)
        json_str = bare.group(1) if bare else cleaned
    parsed = json.loads(json_str)
    if isinstance(parsed, dict):
        parsed = [parsed]
    elif not isinstance(parsed, list):
        return []
    # Callers use dict methods on every item, so discard anything else.
    return [item for item in parsed if isinstance(item, dict)]


def extract_products_from_text(page_text, page_num):
    """Extract product records from one page of text with the text LLM.

    Args:
        page_text: Text of the page; only the first 2000 characters are sent.
        page_num: Page index, used only in the error message.

    Returns:
        A list of dicts parsed from the model's JSON answer. Any failure
        during generation or parsing is printed and yields an empty list.
    """
    text_model, text_tokenizer = load_text_model()
    user_content = f"Text:\n{page_text[:2000]}\n\nOutput JSON:"
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]
    try:
        # Use chat template if available
        if hasattr(text_tokenizer, "apply_chat_template"):
            prompt = text_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            # Keep the instructions in the prompt even without a template.
            prompt = f"{SYSTEM_PROMPT}\n\n{user_content}"
        inputs = text_tokenizer(prompt, return_tensors="pt").to(text_model.device)
        with torch.no_grad():
            # Greedy decoding: temperature has no effect when do_sample=False,
            # so it is not passed.
            outputs = text_model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=False,
                pad_token_id=text_tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs['input_ids'].shape[1]
        output_text = text_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        del inputs, outputs
        return _parse_product_json(output_text)
    except Exception as e:
        print(f"Error extracting from page {page_num}: {e}")
        return []
    finally:
        cleanup_memory()
# NOTE: the vision-analysis, image-extraction and Excel-creation helpers are
# not included in this file. Paste your previous implementations here before
# running the app; process_pdf() below calls create_excel_with_images().
# - analyze_image_with_vision_model()
# - extract_images_from_page()
# - create_excel_with_images()
# ============================================================================
# MAIN PROCESSING FUNCTION
# ============================================================================
def process_pdf(pdf_file, max_pages, progress=gr.Progress()):
    """Extract products from an uploaded PDF and write them to an Excel file.

    Args:
        pdf_file: The uploaded PDF, either as a path string or as an object
            exposing the path through ``.name``; ``None`` if nothing was
            uploaded.
        max_pages: Upper bound on the number of pages to process.
        progress: Gradio progress tracker.

    Returns:
        A ``(excel_path, message)`` tuple. ``excel_path`` is ``None`` when no
        file was produced; ``message`` is Markdown describing the outcome.
    """
    if pdf_file is None:
        return None, "⚠️ Please upload a PDF file first"
    # gr.File(type="filepath") hands over a plain path string; file-like
    # wrappers expose the path as .name instead.
    pdf_path = pdf_file if isinstance(pdf_file, (str, os.PathLike)) else pdf_file.name
    # Sliders may deliver floats, and range() below needs an int.
    max_pages = int(max_pages)
    progress(0, desc="Initializing...")
    try:
        load_text_model()
    except Exception as e:
        return None, f"❌ Error loading text model: {str(e)}"
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)
        # Scratch space for the image-extraction step (omitted in this file).
        img_dir = temp_dir / "images"
        img_dir.mkdir()
        try:
            pages_text = extract_pdf_text(pdf_path)
            doc = fitz.open(pdf_path)
            try:
                all_products = {}
                product_images = {}
                total_pages = min(len(pages_text), doc.page_count, max_pages)
                if total_pages > 20:
                    return None, f"⚠️ PDF has {total_pages} pages. Limit to 20 pages."
                for page_num in range(total_pages):
                    progress(
                        0.2 + (0.6 * page_num / total_pages),
                        desc=f"Processing page {page_num+1}/{total_pages}...",
                    )
                    products = extract_products_from_text(pages_text[page_num], page_num)
                    # Images & matching logic (use your previous code)
                    # ...
                    # Store products; the first occurrence of a code wins.
                    for product in products:
                        if not isinstance(product, dict):
                            continue
                        # The model may emit a number or null for the code.
                        code = str(product.get('Product Code') or '').strip()
                        if code and code not in all_products:
                            product['Product Image File'] = product_images.get(code, '')
                            all_products[code] = product
                    cleanup_memory()
                    time.sleep(0.5)
            finally:
                # Close the PDF on every path, including early returns/errors.
                doc.close()
            progress(0.95, desc="Creating Excel...")
            if not all_products:
                return None, "⚠️ No products found in PDF."
            # The result must outlive this function so Gradio can serve it,
            # so it cannot live inside temp_dir, which is deleted on exit.
            output_dir = Path(tempfile.mkdtemp(prefix="product_extract_"))
            output_excel = output_dir / "products_with_images.xlsx"
            create_excel_with_images(list(all_products.values()), str(output_excel))
            total_products = len(all_products)
            products_with_images = sum(
                1 for p in all_products.values() if p.get('Product Image File')
            )
            match_rate = products_with_images / total_products * 100
            summary = (
                "\n"
                "## ✅ Extraction Complete!\n"
                f"- **Total products found:** {total_products}\n"
                f"- **Products with images:** {products_with_images}\n"
                f"- **Pages processed:** {total_pages}\n"
                f"- **Image match rate:** {match_rate:.1f}%\n"
                "### Download your Excel file below! 📥\n"
            )
            progress(1.0, desc="✅ Done!")
            return str(output_excel), summary
        except Exception as e:
            import traceback
            return None, f"❌ Error: {str(e)}\n```\n{traceback.format_exc()}\n```"
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Two-column layout: upload controls on the left, results on the right.
with gr.Blocks(title="Product Data Extractor", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "\n"
        "# 📦 Product Data Extractor with AI Vision\n"
        "Automatically extract product info and images from PDF using LLaVA + Qwen GGUF.\n"
    )
    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="📄 Upload PDF", file_types=[".pdf"], type="filepath")
            max_pages_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Max Pages")
            extract_btn = gr.Button("🚀 Extract Products", variant="primary", size="lg")
        with gr.Column(scale=1):
            summary_output = gr.Markdown(label="Results")
            excel_output = gr.File(label="📥 Download Excel")
    # process_pdf returns (excel_path, summary_markdown), in that order.
    extract_btn.click(
        fn=process_pdf,
        inputs=[pdf_input, max_pages_slider],
        outputs=[excel_output, summary_output],
    )
# Launch
if __name__ == "__main__":
    # Allow at most 5 requests to wait in the queue.
    demo.queue(max_size=5)
    demo.launch()