Spaces:
Sleeping
Sleeping
File size: 8,999 Bytes
0e9f127 ebe0135 0e9f127 45c7be0 9e25965 45c7be0 0e9f127 701a46b 9e25965 0e9f127 45c7be0 0e9f127 660df5d 9e25965 660df5d 9e25965 660df5d 9e25965 660df5d 9e25965 660df5d 45c7be0 9e25965 660df5d 45c7be0 660df5d 0e9f127 45c7be0 9e25965 660df5d 9e25965 45c7be0 9e25965 0e9f127 9e25965 660df5d 9e25965 660df5d 701a46b 9e25965 701a46b 45c7be0 701a46b 45c7be0 701a46b 660df5d 701a46b 9e25965 45c7be0 660df5d 45c7be0 9e25965 701a46b 660df5d 45c7be0 660df5d 9e25965 660df5d 9e25965 660df5d 45c7be0 9e25965 660df5d 9e25965 660df5d 9e25965 660df5d 9e25965 660df5d 9e25965 660df5d 9e25965 660df5d 9e25965 45c7be0 660df5d 45c7be0 9e25965 660df5d 9e25965 45c7be0 9e25965 45c7be0 660df5d 9e25965 45c7be0 9e25965 660df5d 9e25965 660df5d 0e9f127 93307ce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import os
import gradio as gr
import torch
import gc
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
import json
import re
from typing import Dict, List, Any, Optional
# Hugging Face access token for gated/private model downloads (may be unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# -- Model Cache ----------------------------------------------
# Maps model_id -> (processor, model). Bounded: when full, the
# oldest-inserted entry is evicted before a new model is loaded.
_model_cache = {}
MAX_CACHED_MODELS = 2
# Qwen-VL placeholder token inserted where an image appears in a flattened prompt.
QWEN_VL_IMG_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
def load_model(model_id: str):
    """Return the (processor, model) pair for *model_id*, loading on demand.

    Successful loads are memoized in the module-level ``_model_cache``.
    When the cache holds ``MAX_CACHED_MODELS`` entries, the oldest-inserted
    one is evicted before loading a new model.

    Returns:
        ``(processor, model)`` on success, or ``(None, None)`` on failure.
    """
    if model_id in _model_cache:
        print(f"Cache hit: {model_id}")
        return _model_cache[model_id]
    # Bound memory usage: drop the oldest-inserted model before loading another.
    if len(_model_cache) >= MAX_CACHED_MODELS:
        oldest = next(iter(_model_cache))
        print(f"Unloading: {oldest}")
        del _model_cache[oldest]
        gc.collect()
    print(f"Loading: {model_id}")
    try:
        processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        model = AutoModelForImageTextToText.from_pretrained(
            model_id, device_map=device_map, low_cpu_mem_usage=True, token=HF_TOKEN
        )
        model.eval()
        _model_cache[model_id] = (processor, model)
        print(f"Loaded: {model_id}")
        return processor, model
    except Exception as e:
        # Report the failure instead of silently swallowing it; callers must
        # handle the (None, None) sentinel.
        print(f"Failed to load {model_id}: {e}")
        return None, None
def ui_model_change(model_id):
    """Gradio callback for the model dropdown: preload the model, report status.

    Returns a human-readable status string shown in the Memory Status textbox.
    """
    processor, model = load_model(model_id)
    # Explicit None check: truthiness of a model object is not a reliable
    # "loaded" signal, and load_model returns (None, None) on failure.
    if model is not None:
        return f"Model Active: {model_id}"
    return f"Failed to load {model_id}"
# -- THE FIX: prepare_inputs (from the reference app.py) ------
# Flattens mixed message content (plain strings + lists of parts) into a
# flat string format and hands it to the processor safely.
def prepare_inputs(processor, model, messages: List[Dict]) -> Dict:
    """Flatten mixed chat messages and tokenize them for generation.

    Each message's content may be a plain string (history turns) or a list
    of typed parts (text / image dicts). PIL images are collected and each
    is replaced in the prompt by the Qwen-VL image placeholder token. The
    flattened conversation is run through the chat template and the
    processor, and all resulting tensors are moved to the model's device.
    """
    collected_images = []
    flattened = []
    for message in messages:
        entry = {"role": message.get("role", "user")}
        body = message.get("content", "")
        if not isinstance(body, list):
            # Plain-string history entries pass through unchanged.
            entry["content"] = str(body)
        else:
            chunks = []
            for part in body:
                if not isinstance(part, dict):
                    chunks.append(str(part))
                    continue
                part_type = part.get("type", "")
                if part_type == "text":
                    chunks.append(part.get("text", ""))
                elif part_type == "image":
                    candidate = part.get("image")
                    if candidate is not None and isinstance(candidate, Image.Image):
                        collected_images.append(candidate)
                        chunks.append(QWEN_VL_IMG_TOKEN)
            entry["content"] = "".join(chunks)
        flattened.append(entry)
    prompt = processor.apply_chat_template(flattened, tokenize=False, add_generation_prompt=True)
    call_kwargs = {"text": [prompt], "padding": True, "return_tensors": "pt"}
    # Only pass images when the processor actually has a vision pipeline.
    if collected_images and hasattr(processor, "image_processor"):
        call_kwargs["images"] = collected_images
    inputs = processor(**call_kwargs)
    return {key: value.to(model.device) if torch.is_tensor(value) else value
            for key, value in inputs.items()}
# ββ Enterprise OCR ββββββββββββββββββββββββββββββββββββββββββββ
def extract_tag(tag, text):
    """Return the content of the first ``<tag>...</tag>`` pair in *text*.

    Matching is case-insensitive and spans newlines. Returns "UNKNOWN"
    when the tag pair is absent.
    """
    # The original pattern made the tag name optional (so "<>x</" matched),
    # dropped the closing ">", and lacked DOTALL for multi-line values.
    escaped = re.escape(tag)
    match = re.search(rf"<{escaped}>\s*(.*?)\s*</{escaped}>", text,
                      re.IGNORECASE | re.DOTALL)
    return match.group(1).strip() if match else "UNKNOWN"
def build_enterprise_json(raw_text):
    """Map the model's XML-tagged raw output onto the enterprise JSON schema.

    Returns a pretty-printed JSON string; fields missing from *raw_text*
    come back as "UNKNOWN" via extract_tag.
    """
    field_tags = {
        "civil_number": "ID",
        "full_name": "NAME",
        "date_of_birth": "DOB",
        "nationality": "NAT",
    }
    payload = {
        "DocumentMetadata": {"document_type": "Resident Card", "has_mrz": True},
        "StructuredData": {field: extract_tag(tag, raw_text)
                           for field, tag in field_tags.items()},
    }
    return json.dumps(payload, indent=2, ensure_ascii=False)
def run_document_scan(front_img, model_name):
    """Scan an ID-card image and return the extracted fields as a JSON string.

    Returns an "Error: ..." message when the image is missing or the model
    failed to load, and "Extraction Failed: ..." if generation raises.
    """
    if front_img is None:
        return "Error: Please upload document image."
    processor, model = load_model(model_name)
    if not model:
        return "Error: Model not loaded."
    tag_prompt = "Extract details inside these XML tags ONLY:\n<ID></ID>\n<NAME></NAME>\n<DOB></DOB>\n<NAT></NAT>"
    conversation = [{
        "role": "user",
        "content": [
            {"type": "image", "image": front_img},
            {"type": "text", "text": tag_prompt},
        ],
    }]
    try:
        model_inputs = prepare_inputs(processor, model, conversation)
        with torch.no_grad():
            output_ids = model.generate(**model_inputs, max_new_tokens=150, temperature=0.1)
        # Strip the echoed prompt tokens so only the completion is decoded.
        completions = [full[len(prompt):]
                       for prompt, full in zip(model_inputs["input_ids"], output_ids)]
        raw_output = processor.batch_decode(completions, skip_special_tokens=True)[0]
        return build_enterprise_json(raw_output)
    except Exception as e:
        return f"Extraction Failed: {str(e)}"
# ββ Chat ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def process_chat(text: str, image: Optional[Image.Image], history: List[Dict], model_name: str) -> str:
    """Run one multi-turn chat exchange and return the assistant's reply text.

    Prior turns from *history* are replayed as plain-string messages; the
    new user turn carries the optional image (first) and text as typed parts.
    """
    processor, model = load_model(model_name)
    if not model:
        return "Error: Model not loaded."
    # Replay prior user/assistant turns verbatim.
    messages = []
    for turn in history:
        if turn.get("role") in ("user", "assistant"):
            messages.append({"role": turn["role"], "content": turn["content"]})
    # Assemble the current turn as a list of structured parts.
    parts = []
    if image is not None:
        parts.append({"type": "image", "image": image})
    if text:
        parts.append({"type": "text", "text": text})
    if parts:
        messages.append({"role": "user", "content": parts})
    try:
        # prepare_inputs handles the mixed string/list content safely.
        model_inputs = prepare_inputs(processor, model, messages)
        with torch.no_grad():
            output_ids = model.generate(**model_inputs, max_new_tokens=512, temperature=0.7, top_p=0.9)
        # Drop the echoed prompt tokens before decoding the reply.
        completions = [full[len(prompt):]
                       for prompt, full in zip(model_inputs['input_ids'], output_ids)]
        return processor.batch_decode(completions, skip_special_tokens=True)[0]
    except Exception as e:
        return f"β Error: {str(e)}"
def chat_fn(message: Dict[str, Any], history: List[Dict], model_name: str):
    """MultimodalTextbox submit handler: run the model and extend the history.

    Returns (cleared input box update, updated history) for the two Gradio
    outputs wired to this handler.
    """
    user_text = message.get("text", "")
    attachments = message.get("files", [])
    pil_image = None
    if attachments:
        # Only the first attached file is used; failures are logged, not fatal.
        try:
            pil_image = Image.open(attachments[0]).convert("RGB")
        except Exception as e:
            print(f"Image error: {e}")
    reply = process_chat(user_text, pil_image, history, model_name)
    shown_text = f"{user_text}\nπ [Image attached]" if pil_image else user_text
    history.append({"role": "user", "content": shown_text})
    history.append({"role": "assistant", "content": reply})
    # Clear the input box and push the updated transcript back to the UI.
    return gr.update(value={"text": "", "files": []}), history
# ββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββ
# Top-level UI definition: a shared model selector above two tabs
# (document scanner and multimodal chat).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πͺͺ CSM Smart Document Engine")
    gr.Markdown("_On-Demand Caching β’ Document Scanner β’ Intelligent Multi-Turn Chat_")
    # Model selector plus a read-only status line, shared by both tabs.
    with gr.Row(variant="panel"):
        model_dropdown = gr.Dropdown(
            choices=[
                "Chhagan005/CSM-KIE-Universal",
                "Chhagan005/CSM-DocExtract-8N",
                "Chhagan005/CSM-DocExtract-4N",
            ],
            label="π€ Select Model", value="Chhagan005/CSM-KIE-Universal", interactive=True
        )
        status_bar = gr.Textbox(label="Memory Status", value="Select a model to load into memory", interactive=False)
    # Changing the dropdown eagerly loads the model into the cache.
    model_dropdown.change(fn=ui_model_change, inputs=[model_dropdown], outputs=[status_bar])
    with gr.Tabs():
        # Tab 1: one-shot ID-card field extraction to JSON.
        with gr.TabItem("π Document Scanner"):
            with gr.Row():
                with gr.Column():
                    doc_img = gr.Image(type="pil", label="Upload ID Card")
                    scan_btn = gr.Button("π Extract JSON", variant="primary")
                with gr.Column():
                    json_output = gr.Code(language="json", label="Enterprise Result")
            scan_btn.click(fn=run_document_scan, inputs=[doc_img, model_dropdown], outputs=[json_output])
        # Tab 2: multi-turn chat with optional image attachments.
        with gr.TabItem("π¬ Intelligent Chat"):
            chatbot = gr.Chatbot(label="Chat History", height=450, value=[])
            chat_msg = gr.MultimodalTextbox(
                label="Message", placeholder="Type a message or click π to attach an image...",
                file_types=["image"], submit_btn=True
            )
            # chat_fn returns (cleared input, updated history) to these outputs.
            chat_msg.submit(fn=chat_fn, inputs=[chat_msg, chatbot, model_dropdown], outputs=[chat_msg, chatbot])
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|