# NOTE(review): scraped Hugging Face Spaces page chrome ("Spaces: Sleeping")
# — not part of the program; kept only as a provenance note.
| """ | |
| Chhagan DocVL AI - Document Intelligence Demo | |
| Triple model: CSM-DocExtract-VL + Chhagan_ML-VL-OCR-v1 + Chhagan-DocVL-Qwen3 | |
| """ | |
| import os | |
| import time | |
| import json | |
| import logging | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| from peft import PeftModel | |
| from transformers import ( | |
| Qwen3VLForConditionalGeneration, | |
| Qwen2VLForConditionalGeneration, | |
| AutoProcessor, | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| ) | |
| import utils | |
# Silence the very chatty transformers warning stream; errors still surface.
logging.getLogger("transformers").setLevel(logging.ERROR)
# ---------- CONFIG ----------
# HF auth token for gated/private repos; None is fine for public models.
HF_TOKEN = os.getenv("HF_TOKEN")
# Every model (VLMs and NLLB) is moved to this single device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Registry of selectable models. Each entry:
#   adapter    - LoRA adapter repo id merged onto the base via PEFT (None = base only)
#   base       - HF repo id of the base vision-language model
#   base_class - which *ForConditionalGeneration class loads the base ("qwen3"/"qwen2")
#   use_nllb   - whether extracted fields are additionally translated to English
MODEL_CONFIGS = {
    "Chhagan-DocVL-Qwen3": {
        "adapter": "Chhagan005/Chhagan-DocVL-Qwen3",
        "base": "Qwen/Qwen3-VL-2B-Instruct",
        "base_class": "qwen3",
        "use_nllb": False,
    },
    "Chhagan_ML-VL-OCR-v1": {
        "adapter": "Chhagan005/Chhagan_ML-VL-OCR-v1",
        "base": "Qwen/Qwen2-VL-2B-Instruct",
        "base_class": "qwen2",
        "use_nllb": False,
    },
    "CSM-DocExtract-VL": {
        "adapter": None,
        "base": "Qwen/Qwen3-VL-2B-Instruct",
        "base_class": "qwen3",
        "use_nllb": True,
    },
}
# Shared translation model (loaded once, reused by every use_nllb config).
NLLB_MODEL_ID = "facebook/nllb-200-distilled-600M"
# ---------- PROMPTS ----------
# System prompt shared by both pipeline steps: strict, JSON-only extraction.
SYSTEM_PROMPT = """You are a precise KYC document extraction AI.
Rules:
- Extract text EXACTLY as printed, never guess or hallucinate
- Return ONLY valid JSON, no extra text
- Confidence: 0.0 (unreadable) to 1.0 (crystal clear)
- For each text field: include original script + ISO-639-1 language code
- MRZ lines: copy CHARACTER-PERFECT including angle brackets '<'
"""
# Step 1 prompt: classify the document before field extraction.
DOC_TYPE_PROMPT = """Examine this document image and return JSON only:
{
"document_type": "<passport|national_id|driving_license|visa|residence_permit|other>",
"issuing_country": "<ISO 3166-1 alpha-3 or full name>",
"document_side": "<front|back|single>",
"has_mrz": false,
"primary_language": "<ISO-639-1>",
"confidence": 0.0
}"""
# Step 2 prompt: doubled braces survive .format(doc_type=..., side=...) as
# literal JSON braces in the rendered prompt.
EXTRACT_PROMPT_TEMPLATE = (
    "This is a {doc_type} document (side: {side}).\n"
    "Extract ALL visible fields and return JSON only:\n"
    "{{\n"
    ' "surname": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "given_names": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "nationality": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "date_of_birth": {{"value": "", "confidence": 0.0}},\n'
    ' "document_number":{{"value": "", "confidence": 0.0}},\n'
    ' "expiry_date": {{"value": "", "confidence": 0.0}},\n'
    ' "sex": {{"value": "", "confidence": 0.0}},\n'
    ' "place_of_birth": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "issue_date": {{"value": "", "confidence": 0.0}},\n'
    ' "address": {{"value": "", "lang": "", "confidence": 0.0}},\n'
    ' "mrz_lines": [],\n'
    ' "other_fields": {{}}\n'
    "}}\n"
    "Include ONLY fields that are ACTUALLY VISIBLE. Be honest about confidence."
)
# ---------- LOAD ALL MODELS ----------
# Registries populated by load_model(); keyed by the names in MODEL_CONFIGS.
loaded_models = {}
loaded_processors = {}
# NLLB translator shared across models; both stay None if loading fails,
# which disables translation but keeps the app running.
nllb_tokenizer = None
nllb_model = None
def load_model(name, config):
    """Load one VLM (base model plus optional LoRA adapter) into the registries.

    On success the eval-mode model and its processor are stored in
    ``loaded_models`` / ``loaded_processors`` under *name*; on any failure the
    error is printed and the registries are left untouched so the app can
    start with whatever subset of models did load.
    """
    banner = "=" * 60
    print(banner)
    print(f"π Loading {name} ...")
    print(banner)
    try:
        processor = AutoProcessor.from_pretrained(
            config["base"],
            trust_remote_code=True,
            token=HF_TOKEN,
        )
        # Pick the loader class that matches the base architecture.
        model_cls = (
            Qwen3VLForConditionalGeneration
            if config["base_class"] == "qwen3"
            else Qwen2VLForConditionalGeneration
        )
        base = model_cls.from_pretrained(
            config["base"],
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            token=HF_TOKEN,
        ).to(device)
        adapter_id = config["adapter"]
        if adapter_id:
            model = PeftModel.from_pretrained(base, adapter_id, token=HF_TOKEN)
        else:
            model = base
        loaded_models[name] = model.eval()
        loaded_processors[name] = processor
        print(f" β {name} loaded")
    except Exception as e:
        print(f" β Failed to load {name}: {e}")
# Load all VLMs eagerly at startup so the first user request is fast.
for name, cfg in MODEL_CONFIGS.items():
    load_model(name, cfg)
# Load NLLB (shared, only once)
print("=" * 60)
print("π Loading NLLB-200-distilled-600M (translation) ...")
print("=" * 60)
try:
    nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL_ID)
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_MODEL_ID, torch_dtype=torch.float16
    ).to(device)
    print(" β NLLB loaded")
except Exception as e:
    # Translation is optional: on failure nllb_model stays None and
    # translate_to_english() degrades to a no-op.
    print(f" β οΈ NLLB failed to load (translation disabled): {e}")
# ---------- NLLB TRANSLATION ----------
# Map ISO-639-1 codes (as emitted by the VLM) to NLLB's FLORES-200 codes.
# Languages missing from this map are returned untranslated.
LANG_TO_NLLB = {
    "ar": "arb_Arab", "fa": "pes_Arab", "ur": "urd_Arab",
    "hi": "hin_Deva", "zh": "zho_Hans", "ru": "rus_Cyrl",
    "fr": "fra_Latn", "es": "spa_Latn", "de": "deu_Latn",
    "tr": "tur_Latn", "id": "ind_Latn", "ms": "zsm_Latn",
}
def translate_to_english(text, src_lang_iso):
    """Translate *text* from ISO-639-1 language *src_lang_iso* into English.

    Returns *text* unchanged when NLLB is unavailable, the text is empty,
    the source is already English, or the language is not in LANG_TO_NLLB.
    Any runtime failure also falls back to the original text — translation
    is strictly best-effort.
    """
    if not nllb_model or not text or src_lang_iso == "en":
        return text
    nllb_src = LANG_TO_NLLB.get(src_lang_iso)
    if not nllb_src:
        return text
    try:
        # BUG FIX: nllb_src was computed but never used — the tokenizer must
        # be told the source language, otherwise it tokenizes everything as
        # its default source and translation quality collapses.
        nllb_tokenizer.src_lang = nllb_src
        inputs = nllb_tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=256
        ).to(device)
        # `lang_code_to_id` was removed from NLLB tokenizers in newer
        # transformers releases; convert_tokens_to_ids works on all versions.
        forced_bos = nllb_tokenizer.convert_tokens_to_ids("eng_Latn")
        with torch.no_grad():
            out = nllb_model.generate(
                **inputs,
                forced_bos_token_id=forced_bos,
                max_new_tokens=256,
            )
        return nllb_tokenizer.decode(out[0], skip_special_tokens=True)
    except Exception:
        return text
def apply_nllb_to_result(data):
    """Return a copy of *data* where every localized field gains a "value_en" key.

    A field counts as localized when it is a dict carrying both "value" and
    "lang"; its English translation is added alongside the original value.
    All other entries are passed through untouched.
    """
    out = {}
    for field, payload in data.items():
        is_localized = (
            isinstance(payload, dict) and "value" in payload and "lang" in payload
        )
        if is_localized:
            english = translate_to_english(payload["value"], payload.get("lang", "en"))
            out[field] = dict(payload, value_en=english)
        else:
            out[field] = payload
    return out
# ---------- INFERENCE ----------
def run_inference_step(model, proc, image, prompt, system_prompt=None, temp=0.1):
    """Run one chat-style VLM generation over *image* and return the decoded text.

    Builds an (optional system +) user conversation with an image placeholder,
    renders it through the processor's chat template, generates up to 1024 new
    tokens, and decodes only the completion (prompt tokens are stripped).
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.append(
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    )
    chat_text = proc.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = proc(
        text=[chat_text],
        images=[image],
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=temp,
            top_p=0.9,
            repetition_penalty=1.05,
        )
    # Each output sequence echoes its prompt; drop that prefix before decoding.
    trimmed = [
        seq[len(src):] for src, seq in zip(model_inputs.input_ids, output_ids)
    ]
    decoded = proc.batch_decode(
        trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
# ---------- GRADIO EVENTS ----------
def process_document(image, session_state, model_choice):
    """Gradio generator: run the full extraction pipeline on one document image.

    Yields (text_output, session_state, json_component_update) tuples so the UI
    shows progress between steps. Mutates *session_state*, accumulating one
    entry per scanned side; multi-scan sessions are merged before the final yield.
    """
    if image is None:
        yield "β οΈ Please upload an image.", session_state, gr.update()
        return
    if model_choice not in loaded_models:
        yield f"β Model '{model_choice}' failed to load at startup.", session_state, gr.update()
        return
    active_model = loaded_models[model_choice]
    active_proc = loaded_processors[model_choice]
    use_nllb = MODEL_CONFIGS[model_choice]["use_nllb"]
    # Step 1: Identify Document
    yield "π Analyzing document type...", session_state, gr.update()
    type_response = run_inference_step(
        active_model, active_proc, image, DOC_TYPE_PROMPT, SYSTEM_PROMPT, temp=0.1
    )
    type_data = utils.clean_json_output(type_response)
    doc_type = type_data.get("document_type", "document")
    doc_side = type_data.get("document_side", "single")
    # Step 2: Extract Data (prompt is specialized with the detected type/side)
    yield f"π Extracting data from {doc_type} ({doc_side})...", session_state, gr.update()
    extract_prompt = EXTRACT_PROMPT_TEMPLATE.format(doc_type=doc_type, side=doc_side)
    extract_response = run_inference_step(
        active_model, active_proc, image, extract_prompt, SYSTEM_PROMPT, temp=0.1
    )
    extract_data = utils.clean_json_output(extract_response)
    # Step 3: MRZ Validation — when the MRZ checksum validates, its fields
    # override the vision extraction with confidence 1.0.
    yield "π§ Validating and merging results...", session_state, gr.update()
    if "mrz_lines" in extract_data and extract_data["mrz_lines"]:
        mrz_parsed = utils.parse_mrz(extract_data["mrz_lines"])
        extract_data["_mrz_validation"] = mrz_parsed
        if mrz_parsed.get("mrz_valid"):
            if mrz_parsed.get("document_number"):
                extract_data["document_number"] = {"value": mrz_parsed["document_number"], "confidence": 1.0, "source": "MRZ"}
            if mrz_parsed.get("date_of_birth"):
                extract_data["date_of_birth"] = {"value": mrz_parsed["date_of_birth"], "confidence": 1.0, "source": "MRZ"}
            if mrz_parsed.get("expiry_date"):
                extract_data["expiry_date"] = {"value": mrz_parsed["expiry_date"], "confidence": 1.0, "source": "MRZ"}
    # Step 4: NLLB Translation (CSM mode only)
    if use_nllb and nllb_model:
        yield "π Translating fields to English (NLLB)...", session_state, gr.update()
        extract_data = apply_nllb_to_result(extract_data)
    if session_state is None:
        session_state = {}
    # Key this scan by side; unknown sides get a unique scan_N key.
    side_key = (
        "front" if "front" in doc_side.lower()
        else "back" if "back" in doc_side.lower()
        else "scan_" + str(len(session_state) + 1)
    )
    session_state[side_key] = extract_data
    final_output = extract_data
    if len(session_state) > 1:
        # Merge strategy: first scan is the base; later scans only contribute
        # their MRZ validation (and MRZ-derived document number) when valid.
        merged = {}
        for key, data in session_state.items():
            if not merged:
                merged = data.copy()
            else:
                if "_mrz_validation" in data and data["_mrz_validation"].get("mrz_valid"):
                    merged["_mrz_validation"] = data["_mrz_validation"]
                    merged["document_number"] = data.get("document_number", merged.get("document_number"))
        merged["_session_info"] = f"Merged {len(session_state)} scans: {list(session_state.keys())}"
        final_output = merged
    json_str = json.dumps(final_output, indent=2, ensure_ascii=False)
    yield json_str, session_state, gr.JSON(final_output)
def clear_session():
    """Reset the UI: fresh session dict, status text, and a cleared JSON panel."""
    fresh_state = {}
    return fresh_state, "Session Cleared", None
# ---------- UI ----------
with gr.Blocks(title="Chhagan DocVL AI Pro") as demo:
    # Per-browser-session dict of extracted sides (front/back/scan_N),
    # threaded through process_document and mirrored in session_display.
    state = gr.State({})
    gr.Markdown("# ποΈ **Chhagan DocVL AI Pro**")
    gr.Markdown("### Advanced Document Intelligence Β· Strict JSON Β· MRZ Validation Β· Multilingual Translation")
    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Radio(
                choices=list(MODEL_CONFIGS.keys()),
                label="Select Model",
                value="Chhagan-DocVL-Qwen3",
            )
            gr.Markdown("""
- **Chhagan-DocVL-Qwen3** β Best for structured docs
- **Chhagan_ML-VL-OCR-v1** β Best for raw OCR
- **CSM-DocExtract-VL** β Full pipeline + NLLB translation (200+ languages)
""")
            image_upload = gr.Image(type="pil", label="Upload Document", height=400)
            with gr.Row():
                submit_btn = gr.Button("π Scan Document", variant="primary")
                clear_btn = gr.Button("ποΈ Clear Session", variant="secondary")
            gr.Markdown("**Session Status:**")
            session_display = gr.JSON(label="Current Session Data", value={})
        with gr.Column(scale=1):
            gr.Markdown("## Extraction Results")
            output_json = gr.JSON(label="Structured Output")
            output_text = gr.Code(label="Raw JSON Response", language="json")
    # Scan: stream process_document's progress yields, then mirror the
    # final session state into the status panel.
    submit_btn.click(
        fn=process_document,
        inputs=[image_upload, state, model_choice],
        outputs=[output_text, state, output_json],
    ).then(
        fn=lambda s: s,
        inputs=[state],
        outputs=[session_display],
    )
    clear_btn.click(
        fn=clear_session,
        outputs=[state, output_text, output_json],
    )
if __name__ == "__main__":
    # queue() is required for generator (streaming) event handlers.
    demo.queue().launch()