Spaces:
Running
Running
| import os | |
| import gradio as gr | |
| import torch | |
| import gc | |
| from PIL import Image | |
| from transformers import AutoModelForImageTextToText, AutoProcessor | |
| import json | |
| import re | |
| from typing import Dict, List, Any, Optional | |
# Optional Hugging Face token (for gated/private model repos); None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# ── Model Cache ──────────────────────────────────────────────
# Maps model_id -> (processor, model). Dict insertion order doubles as FIFO age
# for eviction in load_model.
_model_cache = {}
# At most this many (processor, model) pairs stay resident before eviction.
MAX_CACHED_MODELS = 2
# Qwen-VL image placeholder inserted into the flattened prompt text wherever an
# image appears in a message's content list (see prepare_inputs).
QWEN_VL_IMG_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
def load_model(model_id: str):
    """Return ``(processor, model)`` for *model_id*, using a small FIFO cache.

    At most ``MAX_CACHED_MODELS`` models stay resident; the oldest is evicted
    before a new one is loaded.

    Returns:
        Tuple of (processor, model), or ``(None, None)`` when loading fails.
    """
    if model_id in _model_cache:
        print(f"β‘ Cache Hit: {model_id}")
        return _model_cache[model_id]
    if len(_model_cache) >= MAX_CACHED_MODELS:
        # Dicts preserve insertion order, so the first key is the oldest entry.
        oldest = next(iter(_model_cache))
        print(f"π§Ή Unloading: {oldest}")
        del _model_cache[oldest]
        gc.collect()
        if torch.cuda.is_available():
            # Return freed VRAM to the driver so the next load actually fits.
            torch.cuda.empty_cache()
    print(f"β³ Loading: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        model = AutoModelForImageTextToText.from_pretrained(
            model_id, device_map=device_map, low_cpu_mem_usage=True, token=HF_TOKEN
        )
        model.eval()
        _model_cache[model_id] = (processor, model)
        print(f"β Loaded: {model_id}")
        return processor, model
    except Exception as e:
        # FIX: the original discarded the exception silently, making load
        # failures impossible to diagnose from the Space logs.
        print(f"Failed to load {model_id}: {e}")
        return None, None
def ui_model_change(model_id):
    """Gradio callback: load *model_id* and return a status string for the UI."""
    processor, model = load_model(model_id)
    # FIX: test identity, not truthiness — some framework objects define
    # __bool__/__len__, so "if model" is fragile; load_model returns None on failure.
    if model is not None:
        return f"β Model Active: {model_id}"
    return f"β Failed to load {model_id}"
# ── prepare_inputs ───────────────────────────────────────────
# Flattens mixed content (plain strings and list-style parts) into the flat
# format the processor expects, and hands it over safely.
def prepare_inputs(processor, model, messages: List[Dict]) -> Dict:
    """Convert chat *messages* into model-ready tensors on ``model.device``.

    Content may be a plain string (history turns) or a list of parts; "text"
    parts contribute their text and "image" parts (PIL images only) contribute
    the Qwen-VL placeholder token plus the image itself. Non-dict parts are
    stringified. Tensors in the processor output are moved to the model device.
    """
    images = []
    flattened = []
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content", "")
        if not isinstance(content, list):
            # Plain-string history entries pass straight through.
            flattened.append({"role": role, "content": str(content)})
            continue
        pieces = []
        for part in content:
            if not isinstance(part, dict):
                pieces.append(str(part))
            elif part.get("type", "") == "text":
                pieces.append(part.get("text", ""))
            elif part.get("type", "") == "image":
                candidate = part.get("image")
                if isinstance(candidate, Image.Image):
                    images.append(candidate)
                    pieces.append(QWEN_VL_IMG_TOKEN)
        flattened.append({"role": role, "content": "".join(pieces)})
    prompt = processor.apply_chat_template(flattened, tokenize=False, add_generation_prompt=True)
    if images and hasattr(processor, "image_processor"):
        batch = processor(text=[prompt], images=images, padding=True, return_tensors="pt")
    else:
        batch = processor(text=[prompt], padding=True, return_tensors="pt")
    return {key: value.to(model.device) if torch.is_tensor(value) else value for key, value in batch.items()}
| # ββ Enterprise OCR ββββββββββββββββββββββββββββββββββββββββββββ | |
def extract_tag(tag, text):
    """Return the trimmed text between ``<tag>...</tag>`` in *text*, else "UNKNOWN".

    Matching is case-insensitive and spans newlines. A lenient fallback accepts
    a truncated closing tag (model output cut off at ``</``).

    FIXES vs. original: the tag is regex-escaped; the strict pattern is tried
    BEFORE the lenient one (the old lenient-first pattern ``<(?:tag)?>`` could
    match a stray ``<>`` and its ``</(?:tag)?`` close matched ANY closing tag,
    capturing across unrelated tags); DOTALL allows multi-line values.
    """
    t = re.escape(str(tag))
    flags = re.IGNORECASE | re.DOTALL
    match = re.search(f"<{t}>(.*?)</{t}>", text, flags)
    if not match:
        # Lenient fallback: the closing tag may be truncated in model output.
        match = re.search(f"<{t}>(.*?)</", text, flags)
    return match.group(1).strip() if match else "UNKNOWN"
def build_enterprise_json(raw_text):
    """Wrap tag-extracted fields from *raw_text* in the enterprise JSON envelope.

    Returns a pretty-printed JSON string with DocumentMetadata plus the four
    StructuredData fields pulled via extract_tag (missing tags yield "UNKNOWN").
    """
    # Output-field name -> XML tag emitted by the extraction prompt.
    field_tags = {
        "civil_number": "ID",
        "full_name": "NAME",
        "date_of_birth": "DOB",
        "nationality": "NAT",
    }
    payload = {
        "DocumentMetadata": {"document_type": "Resident Card", "has_mrz": True},
        "StructuredData": {
            field: extract_tag(tag, raw_text) for field, tag in field_tags.items()
        },
    }
    return json.dumps(payload, indent=2, ensure_ascii=False)
def run_document_scan(front_img, model_name):
    """OCR an ID-card image with *model_name* and return the enterprise JSON string.

    Args:
        front_img: PIL image of the document front, or None.
        model_name: HF model id, loaded on demand through the cache.

    Returns:
        JSON string on success, or a human-readable "Error: ..." /
        "Extraction Failed: ..." string.
    """
    if front_img is None:
        return "Error: Please upload document image."
    processor, model = load_model(model_name)
    if not model:
        return "Error: Model not loaded."
    # Constrain the model to emit values inside known XML tags so extract_tag
    # can parse them deterministically.
    prompt = "Extract details inside these XML tags ONLY:\n<ID></ID>\n<NAME></NAME>\n<DOB></DOB>\n<NAT></NAT>"
    messages = [{"role": "user", "content": [{"type": "image", "image": front_img}, {"type": "text", "text": prompt}]}]
    try:
        inputs = prepare_inputs(processor, model, messages)
        with torch.no_grad():
            # FIX: the original passed temperature=0.1 without do_sample=True,
            # which transformers ignores (with a warning) — decoding was greedy
            # anyway, so make that explicit.
            generated_ids = model.generate(**inputs, max_new_tokens=150, do_sample=False)
        # Strip the echoed prompt tokens so only newly generated text is decoded.
        trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
        raw_output = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
        return build_enterprise_json(raw_output)
    except Exception as e:
        return f"Extraction Failed: {str(e)}"
| # ββ Chat ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def process_chat(text: str, image: Optional[Image.Image], history: List[Dict], model_name: str) -> str:
    """Run one multi-turn chat round against the selected model.

    Args:
        text: User's message (may be empty when only an image is sent).
        image: Optional PIL image attached to the current turn.
        history: Prior turns as {"role", "content"} dicts (string content).
        model_name: HF model id, loaded on demand through the cache.

    Returns:
        Assistant reply text, or an error string.
    """
    processor, model = load_model(model_name)
    if not model:
        return "Error: Model not loaded."
    # Replay only proper chat roles from history; other entries are dropped.
    messages = [{"role": m["role"], "content": m["content"]}
                for m in history if m.get("role") in ("user", "assistant")]
    # Current turn uses list-style content so prepare_inputs can interleave
    # the image placeholder and the text.
    content = []
    if image is not None:
        content.append({"type": "image", "image": image})
    if text:
        content.append({"type": "text", "text": text})
    if content:
        messages.append({"role": "user", "content": content})
    try:
        inputs = prepare_inputs(processor, model, messages)
        with torch.no_grad():
            # FIX: temperature/top_p are silently ignored by transformers
            # unless do_sample=True — the original actually decoded greedily.
            generated_ids = model.generate(
                **inputs, max_new_tokens=512, do_sample=True, temperature=0.7, top_p=0.9
            )
        # Decode only the newly generated tokens (strip the echoed prompt).
        trimmed = [o[len(i):] for i, o in zip(inputs['input_ids'], generated_ids)]
        return processor.batch_decode(trimmed, skip_special_tokens=True)[0]
    except Exception as e:
        return f"β Error: {str(e)}"
def chat_fn(message: Dict[str, Any], history: List[Dict], model_name: str):
    """Gradio submit handler: unpack the multimodal message, chat, update history.

    Returns a cleared MultimodalTextbox update plus the extended history.
    """
    text = message.get("text", "")
    attachments = message.get("files", [])
    image = None
    if attachments:
        try:
            image = Image.open(attachments[0]).convert("RGB")
        except Exception as e:
            print(f"Image error: {e}")
    response = process_chat(text, image, history, model_name)
    if image:
        display_text = f"{text}\nπ [Image attached]"
    else:
        display_text = text
    history.extend([
        {"role": "user", "content": display_text},
        {"role": "assistant", "content": response},
    ])
    return gr.update(value={"text": "", "files": []}), history
# ── Gradio UI ────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πͺͺ CSM Smart Document Engine")
    gr.Markdown("_On-Demand Caching β’ Document Scanner β’ Intelligent Multi-Turn Chat_")
    # Model selector and a read-only status line reflecting the model cache.
    with gr.Row(variant="panel"):
        model_dropdown = gr.Dropdown(
            choices=[
                "Chhagan005/CSM-KIE-Universal",
                "Chhagan005/CSM-DocExtract-8N",
                "Chhagan005/CSM-DocExtract-4N",
            ],
            label="π€ Select Model", value="Chhagan005/CSM-KIE-Universal", interactive=True
        )
        status_bar = gr.Textbox(label="Memory Status", value="Select a model to load into memory", interactive=False)
    # Eagerly (re)load the model whenever the dropdown selection changes.
    model_dropdown.change(fn=ui_model_change, inputs=[model_dropdown], outputs=[status_bar])
    with gr.Tabs():
        # Tab 1: one-shot ID-card extraction to the enterprise JSON envelope.
        with gr.TabItem("π Document Scanner"):
            with gr.Row():
                with gr.Column():
                    doc_img = gr.Image(type="pil", label="Upload ID Card")
                    scan_btn = gr.Button("π Extract JSON", variant="primary")
                with gr.Column():
                    json_output = gr.Code(language="json", label="Enterprise Result")
            scan_btn.click(fn=run_document_scan, inputs=[doc_img, model_dropdown], outputs=[json_output])
        # Tab 2: multi-turn chat with optional image attachments.
        with gr.TabItem("π¬ Intelligent Chat"):
            chatbot = gr.Chatbot(label="Chat History", height=450, value=[])
            chat_msg = gr.MultimodalTextbox(
                label="Message", placeholder="Type a message or click π to attach an image...",
                file_types=["image"], submit_btn=True
            )
            # chat_fn both clears the input box and appends the new turns.
            chat_msg.submit(fn=chat_fn, inputs=[chat_msg, chatbot, model_dropdown], outputs=[chat_msg, chatbot])
if __name__ == "__main__":
    # 0.0.0.0 so the app is reachable from outside the Spaces container.
    demo.launch(server_name="0.0.0.0", server_port=7860)