# Source: Hugging Face Space by Chhagan005 — app.py (commit aa36fcb)
"""
Chhagan DocVL AI - Document Intelligence Demo
Triple model: CSM-DocExtract-VL + Chhagan_ML-VL-OCR-v1 + Chhagan-DocVL-Qwen3
"""
import os
import time
import json
import logging
import spaces
import gradio as gr
import torch
from peft import PeftModel
from transformers import (
Qwen3VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
AutoProcessor,
AutoTokenizer,
AutoModelForSeq2SeqLM,
)
import utils
logging.getLogger("transformers").setLevel(logging.ERROR)
# ---------- CONFIG ----------
HF_TOKEN = os.getenv("HF_TOKEN")  # optional token for gated/private HF repos
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Registry of user-selectable VLMs:
#   adapter    — LoRA adapter repo applied on top of the base (None = base only)
#   base       — Hugging Face base checkpoint id
#   base_class — which *ForConditionalGeneration class to instantiate
#   use_nllb   — whether extracted fields are post-translated with NLLB
MODEL_CONFIGS = {
    "Chhagan-DocVL-Qwen3": {
        "adapter": "Chhagan005/Chhagan-DocVL-Qwen3",
        "base": "Qwen/Qwen3-VL-2B-Instruct",
        "base_class": "qwen3",
        "use_nllb": False,
    },
    "Chhagan_ML-VL-OCR-v1": {
        "adapter": "Chhagan005/Chhagan_ML-VL-OCR-v1",
        "base": "Qwen/Qwen2-VL-2B-Instruct",
        "base_class": "qwen2",
        "use_nllb": False,
    },
    "CSM-DocExtract-VL": {
        "adapter": None,  # base checkpoint used as-is (no LoRA)
        "base": "Qwen/Qwen3-VL-2B-Instruct",
        "base_class": "qwen3",
        "use_nllb": True,
    },
}
# Shared translation model (loaded once at startup, used by one mode only).
NLLB_MODEL_ID = "facebook/nllb-200-distilled-600M"
# ---------- PROMPTS ----------
# System prompt shared by both inference steps: forces JSON-only,
# no-hallucination extraction with per-field confidence scores.
SYSTEM_PROMPT = """You are a precise KYC document extraction AI.
Rules:
- Extract text EXACTLY as printed, never guess or hallucinate
- Return ONLY valid JSON, no extra text
- Confidence: 0.0 (unreadable) to 1.0 (crystal clear)
- For each text field: include original script + ISO-639-1 language code
- MRZ lines: copy CHARACTER-PERFECT including angle brackets '<'
"""

# Step 1 prompt: classify the document before extracting fields.
DOC_TYPE_PROMPT = """Examine this document image and return JSON only:
{
"document_type": "<passport|national_id|driving_license|visa|residence_permit|other>",
"issuing_country": "<ISO 3166-1 alpha-3 or full name>",
"document_side": "<front|back|single>",
"has_mrz": false,
"primary_language": "<ISO-639-1>",
"confidence": 0.0
}"""

# Step 2 prompt: filled in with the detected doc_type/side via str.format();
# literal braces in the JSON skeleton are doubled ({{ }}) for that reason.
EXTRACT_PROMPT_TEMPLATE = (
    "This is a {doc_type} document (side: {side}).\n"
    "Extract ALL visible fields and return JSON only:\n"
    "{{\n"
    ' "surname": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "given_names": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "nationality": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "date_of_birth": {{"value": "", "confidence": 0.0}},\n'
    ' "document_number":{{"value": "", "confidence": 0.0}},\n'
    ' "expiry_date": {{"value": "", "confidence": 0.0}},\n'
    ' "sex": {{"value": "", "confidence": 0.0}},\n'
    ' "place_of_birth": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "issue_date": {{"value": "", "confidence": 0.0}},\n'
    ' "address": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "mrz_lines": [],\n'
    ' "other_fields": {{}}\n'
    "}}\n"
    "Include ONLY fields that are ACTUALLY VISIBLE. Be honest about confidence."
)
# ---------- LOAD ALL MODELS ----------
# Module-level registries populated at startup; inference code looks
# models up here by their MODEL_CONFIGS key.
loaded_models = {}      # model name -> eval()'d (optionally PEFT-wrapped) VLM
loaded_processors = {}  # model name -> matching AutoProcessor
# NLLB translator pair; remain None if loading fails (translation disabled).
nllb_tokenizer = None
nllb_model = None
def load_model(name, config):
    """Load one VLM (base checkpoint + optional LoRA adapter) and register it.

    On success the eval-mode model and its processor are stored in the
    module-level ``loaded_models`` / ``loaded_processors`` dicts under
    *name*. Loading is best-effort: any failure is printed and the model
    is simply absent from the registries.
    """
    banner = "=" * 60
    print(banner)
    print(f"πŸš€ Loading {name} ...")
    print(banner)
    try:
        proc = AutoProcessor.from_pretrained(
            config["base"], trust_remote_code=True, token=HF_TOKEN,
        )
        # Pick the generation class matching the base architecture.
        base_cls = (
            Qwen3VLForConditionalGeneration
            if config["base_class"] == "qwen3"
            else Qwen2VLForConditionalGeneration
        )
        base = base_cls.from_pretrained(
            config["base"], trust_remote_code=True,
            torch_dtype=torch.bfloat16, token=HF_TOKEN,
        ).to(device)
        # Wrap with the LoRA adapter when one is configured.
        if config["adapter"]:
            wrapped = PeftModel.from_pretrained(base, config["adapter"], token=HF_TOKEN)
        else:
            wrapped = base
        loaded_models[name] = wrapped.eval()
        loaded_processors[name] = proc
        print(f" βœ… {name} loaded")
    except Exception as e:
        print(f" ❌ Failed to load {name}: {e}")
# Load all VLMs (best-effort: a failed model is skipped, the app still starts)
for name, cfg in MODEL_CONFIGS.items():
    load_model(name, cfg)

# Load NLLB (shared, only once)
print("=" * 60)
print("πŸš€ Loading NLLB-200-distilled-600M (translation) ...")
print("=" * 60)
try:
    nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL_ID)
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_MODEL_ID, torch_dtype=torch.float16
    ).to(device)
    print(" βœ… NLLB loaded")
except Exception as e:
    # Translation is optional: the globals stay None and
    # translate_to_english() becomes a pass-through.
    print(f" ⚠️ NLLB failed to load (translation disabled): {e}")
# ---------- NLLB TRANSLATION ----------
# ISO-639-1 code (as emitted by the VLM in each field's "lang") ->
# NLLB-200 language tag. Languages missing here pass through untranslated.
LANG_TO_NLLB = {
    "ar": "arb_Arab", "fa": "pes_Arab", "ur": "urd_Arab",
    "hi": "hin_Deva", "zh": "zho_Hans", "ru": "rus_Cyrl",
    "fr": "fra_Latn", "es": "spa_Latn", "de": "deu_Latn",
    "tr": "tur_Latn", "id": "ind_Latn", "ms": "zsm_Latn",
}
def translate_to_english(text, src_lang_iso):
    """Translate *text* from *src_lang_iso* (ISO-639-1) into English via NLLB.

    Returns the input unchanged whenever translation is unavailable or
    unnecessary: model not loaded, empty text, source already English,
    unmapped language code, or any runtime failure (best-effort).
    """
    if nllb_model is None or not text or src_lang_iso == "en":
        return text
    nllb_src = LANG_TO_NLLB.get(src_lang_iso)
    if not nllb_src:
        return text
    try:
        # BUG FIX: the original computed nllb_src but never told the
        # tokenizer, so inputs were always encoded with the default
        # source language. NLLB requires src_lang before tokenizing.
        nllb_tokenizer.src_lang = nllb_src
        inputs = nllb_tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=256
        ).to(device)
        # BUG FIX: `lang_code_to_id` was removed from NLLB tokenizers in
        # newer transformers releases; convert_tokens_to_ids works on all.
        forced_bos = nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
        with torch.no_grad():
            out = nllb_model.generate(
                **inputs,
                forced_bos_token_id=forced_bos,  # force English output
                max_new_tokens=256,
            )
        return nllb_tokenizer.decode(out[0], skip_special_tokens=True)
    except Exception:
        # Translation is an enhancement; never fail the extraction pipeline.
        return text
def apply_nllb_to_result(data):
    """Return a copy of *data* with English translations added.

    Every value shaped like ``{"value": ..., "lang": ...}`` gains a
    "value_en" key holding the NLLB translation; all other entries are
    carried over untouched.
    """
    result = {}
    for field_name, field in data.items():
        is_text_field = (
            isinstance(field, dict) and "value" in field and "lang" in field
        )
        if is_text_field:
            english = translate_to_english(field["value"], field.get("lang", "en"))
            result[field_name] = dict(field, value_en=english)
        else:
            result[field_name] = field
    return result
# ---------- INFERENCE ----------
def run_inference_step(model, proc, image, prompt, system_prompt=None, temp=0.1):
    """Run one chat-style VLM generation over *image* with *prompt*.

    Builds an optional system message plus a user message carrying the
    image and text, applies the processor's chat template, samples up to
    1024 new tokens, and returns only the newly generated text (the
    echoed prompt tokens are stripped before decoding).
    """
    chat = [{"role": "system", "content": system_prompt}] if system_prompt else []
    chat.append(
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    )
    templated = proc.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    model_inputs = proc(
        text=[templated],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=temp,
            top_p=0.9,
            repetition_penalty=1.05,
        )
    # Keep only the tokens generated past each input's length.
    new_token_ids = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(model_inputs.input_ids, output_ids)
    ]
    decoded = proc.batch_decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
# ---------- GRADIO EVENTS ----------
@spaces.GPU(duration=120)  # ZeroGPU: allocate a GPU for up to 120 s per call
def process_document(image, session_state, model_choice):
    """Run the full extraction pipeline on one uploaded document image.

    Generator used as a Gradio event handler: each intermediate ``yield``
    streams a progress message into the raw-output textbox; the final
    yield emits (json_string, updated session_state, gr.JSON payload).
    """
    if image is None:
        yield "⚠️ Please upload an image.", session_state, gr.update()
        return
    if model_choice not in loaded_models:
        # Model may be missing because its startup load failed.
        yield f"❌ Model '{model_choice}' failed to load at startup.", session_state, gr.update()
        return
    active_model = loaded_models[model_choice]
    active_proc = loaded_processors[model_choice]
    use_nllb = MODEL_CONFIGS[model_choice]["use_nllb"]
    # Step 1: Identify Document
    yield "πŸ” Analyzing document type...", session_state, gr.update()
    type_response = run_inference_step(
        active_model, active_proc, image, DOC_TYPE_PROMPT, SYSTEM_PROMPT, temp=0.1
    )
    type_data = utils.clean_json_output(type_response)
    doc_type = type_data.get("document_type", "document")
    doc_side = type_data.get("document_side", "single")
    # Step 2: Extract Data — prompt is specialised with the detected type/side
    yield f"πŸ“„ Extracting data from {doc_type} ({doc_side})...", session_state, gr.update()
    extract_prompt = EXTRACT_PROMPT_TEMPLATE.format(doc_type=doc_type, side=doc_side)
    extract_response = run_inference_step(
        active_model, active_proc, image, extract_prompt, SYSTEM_PROMPT, temp=0.1
    )
    extract_data = utils.clean_json_output(extract_response)
    # Step 3: MRZ Validation — checksum-validated MRZ values override the
    # OCR'd fields (confidence forced to 1.0, provenance tagged "MRZ").
    yield "🧠 Validating and merging results...", session_state, gr.update()
    if "mrz_lines" in extract_data and extract_data["mrz_lines"]:
        mrz_parsed = utils.parse_mrz(extract_data["mrz_lines"])
        extract_data["_mrz_validation"] = mrz_parsed
        if mrz_parsed.get("mrz_valid"):
            if mrz_parsed.get("document_number"):
                extract_data["document_number"] = {"value": mrz_parsed["document_number"], "confidence": 1.0, "source": "MRZ"}
            if mrz_parsed.get("date_of_birth"):
                extract_data["date_of_birth"] = {"value": mrz_parsed["date_of_birth"], "confidence": 1.0, "source": "MRZ"}
            if mrz_parsed.get("expiry_date"):
                extract_data["expiry_date"] = {"value": mrz_parsed["expiry_date"], "confidence": 1.0, "source": "MRZ"}
    # Step 4: NLLB Translation (CSM mode only)
    if use_nllb and nllb_model:
        yield "🌐 Translating fields to English (NLLB)...", session_state, gr.update()
        extract_data = apply_nllb_to_result(extract_data)
    if session_state is None:
        session_state = {}
    # Key this scan by detected side so front/back of the same document
    # accumulate in one session; unknown sides get a running "scan_N" key.
    side_key = (
        "front" if "front" in doc_side.lower()
        else "back" if "back" in doc_side.lower()
        else "scan_" + str(len(session_state) + 1)
    )
    session_state[side_key] = extract_data
    final_output = extract_data
    if len(session_state) > 1:
        # Merge multiple scans: the first scan seeds the merged dict; later
        # scans only contribute when they carry a checksum-valid MRZ.
        merged = {}
        for key, data in session_state.items():
            if not merged:
                merged = data.copy()
            else:
                if "_mrz_validation" in data and data["_mrz_validation"].get("mrz_valid"):
                    merged["_mrz_validation"] = data["_mrz_validation"]
                    merged["document_number"] = data.get("document_number", merged.get("document_number"))
        merged["_session_info"] = f"Merged {len(session_state)} scans: {list(session_state.keys())}"
        final_output = merged
    json_str = json.dumps(final_output, indent=2, ensure_ascii=False)
    yield json_str, session_state, gr.JSON(final_output)
def clear_session():
    """Reset the multi-scan session: empty state, status text, cleared JSON."""
    empty_state = {}
    return empty_state, "Session Cleared", None
# ---------- UI ----------
with gr.Blocks(title="Chhagan DocVL AI Pro") as demo:
    # Per-browser-session dict of extracted scans, keyed by document side.
    state = gr.State({})
    gr.Markdown("# πŸ‘οΈ **Chhagan DocVL AI Pro**")
    gr.Markdown("### Advanced Document Intelligence Β· Strict JSON Β· MRZ Validation Β· Multilingual Translation")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: model picker + upload + session controls.
            model_choice = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                label="Select Model",
                value="Chhagan-DocVL-Qwen3",
            )
            gr.Markdown("""
- **Chhagan-DocVL-Qwen3** β€” Best for structured docs
- **Chhagan_ML-VL-OCR-v1** β€” Best for raw OCR
- **CSM-DocExtract-VL** β€” Full pipeline + NLLB translation (200+ languages)
""")
            image_upload = gr.Image(type="pil", label="Upload Document", height=400)
            with gr.Row():
                submit_btn = gr.Button("πŸš€ Scan Document", variant="primary")
                clear_btn = gr.Button("πŸ—‘οΈ Clear Session", variant="secondary")
            gr.Markdown("**Session Status:**")
            session_display = gr.JSON(label="Current Session Data", value={})
        with gr.Column(scale=1):
            # Right column: structured + raw views of the extraction result.
            gr.Markdown("## Extraction Results")
            output_json = gr.JSON(label="Structured Output")
            output_text = gr.Code(label="Raw JSON Response", language="json")
    # process_document is a generator: intermediate yields stream progress
    # into output_text; .then() mirrors the final state into session_display.
    submit_btn.click(
        fn=process_document,
        inputs=[image_upload, state, model_choice],
        outputs=[output_text, state, output_json],
    ).then(
        fn=lambda s: s,
        inputs=[state],
        outputs=[session_display],
    )
    clear_btn.click(
        fn=clear_session,
        outputs=[state, output_text, output_json],
    )

if __name__ == "__main__":
    # queue() is required so generator event handlers can stream updates.
    demo.queue().launch()