Chhagan005 commited on
Commit
660df5d
Β·
verified Β·
1 Parent(s): 29d1fd9

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +83 -99
app.py CHANGED
@@ -5,166 +5,170 @@ import torch
5
  import gc
6
  from PIL import Image
7
  from transformers import AutoModelForImageTextToText, AutoProcessor
8
- from qwen_vl_utils import process_vision_info
9
  import json
10
  import re
11
  from typing import Dict, List, Any, Optional
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
 
15
- # ──────────────────────────────────────────────────────────────
16
- # 1. Smart Memory Cache (From your reference, heavily optimized)
17
- # ──────────────────────────────────────────────────────────────
18
  _model_cache = {}
19
- MAX_CACHED_MODELS = 2 # Limits RAM usage on free HF Space CPU
 
20
 
21
  def load_model(model_id: str):
22
- # 1. Agar cache me hai, wahi se return karo (0 loading time)
23
  if model_id in _model_cache:
24
- print(f"⚑ Fast Load: {model_id} already in cache!")
25
  return _model_cache[model_id]
26
 
27
- # 2. RAM check (Agar memory full hai, toh sabse purana model nikal do)
28
  if len(_model_cache) >= MAX_CACHED_MODELS:
29
- oldest_model = list(_model_cache.keys())[0]
30
- print(f"🧹 Memory Full! Unloading old model: {oldest_model}")
31
- del _model_cache[oldest_model]
32
  gc.collect()
33
 
34
- # 3. Pehli baar model load karo
35
- print(f"⏳ Loading model into memory: {model_id}")
36
  try:
37
  processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
38
- # Check for GPU (from reference)
39
- device_type = "auto" if torch.cuda.is_available() else "cpu"
40
-
41
  model = AutoModelForImageTextToText.from_pretrained(
42
- model_id,
43
- device_map=device_type,
44
- low_cpu_mem_usage=True,
45
- token=HF_TOKEN
46
  )
47
  model.eval()
48
-
49
  _model_cache[model_id] = (processor, model)
50
- print(f"βœ… {model_id} loaded successfully!")
51
  return processor, model
52
  except Exception as e:
53
- print(f"❌ Error loading {model_id}: {str(e)}")
54
  return None, None
55
 
56
  def ui_model_change(model_id):
57
  processor, model = load_model(model_id)
58
- if model:
59
- return f"βœ… Model Active: {model_id} (Cached in Memory)"
60
  return f"❌ Failed to load {model_id}"
61
 
62
- # ──────────────────────────────────────────────────────────────
63
- # 2. Enterprise OCR JSON Parsing (Our logic)
64
- # ──────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def extract_tag(tag, text):
66
  match = re.search(f"<(?:{tag})?>(.*?)</(?:{tag})?", text, re.IGNORECASE)
67
  if not match: match = re.search(f"<{tag}>(.*?)</{tag}>", text, re.IGNORECASE)
68
  return match.group(1).strip() if match else "UNKNOWN"
69
 
70
  def build_enterprise_json(raw_text):
71
- civ_id = extract_tag("ID", raw_text)
72
- name = extract_tag("NAME", raw_text)
73
- dob = extract_tag("DOB", raw_text)
74
- nat = extract_tag("NAT", raw_text)
75
-
76
  result_json = {
77
  "DocumentMetadata": {"document_type": "Resident Card", "has_mrz": True},
78
  "StructuredData": {
79
- "civil_number": civ_id, "full_name": name, "date_of_birth": dob, "nationality": nat
 
 
 
80
  }
81
  }
82
  return json.dumps(result_json, indent=2, ensure_ascii=False)
83
 
84
  def run_document_scan(front_img, model_name):
85
  if front_img is None: return "Error: Please upload document image."
86
-
87
  processor, model = load_model(model_name)
88
  if not model: return "Error: Model not loaded."
89
 
90
  prompt = "Extract details inside these XML tags ONLY:\n<ID></ID>\n<NAME></NAME>\n<DOB></DOB>\n<NAT></NAT>"
91
  messages = [{"role": "user", "content": [{"type": "image", "image": front_img}, {"type": "text", "text": prompt}]}]
92
-
93
- try:
94
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
95
- image_inputs, video_inputs = process_vision_info(messages)
96
- inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
97
- inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
98
 
 
 
99
  with torch.no_grad():
100
  generated_ids = model.generate(**inputs, max_new_tokens=150, temperature=0.1)
101
-
102
  trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
103
  raw_output = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
104
  return build_enterprise_json(raw_output)
105
  except Exception as e:
106
- return f"Extraction Failed: {str(e)}"
107
 
108
- # ──────────────────────────────────────────────────────────────
109
- # 3. Chat Inference (Reference Architecture Logic)
110
- # ──────────────────────────────────────────────────────────────
111
- def process_chat(message: str, image: Optional[Image.Image], history: List[Dict[str, Any]], model_name: str) -> str:
112
  processor, model = load_model(model_name)
113
  if not model: return "Error: Model not loaded."
114
 
 
 
 
 
 
115
  content = []
116
  if image is not None:
117
  content.append({"type": "image", "image": image})
118
- if message:
119
- content.append({"type": "text", "text": message})
120
 
121
- # Prepare pure history dictionary
122
- messages = [{"role": m["role"], "content": m["content"]} for m in history if m.get("role") in ("user", "assistant")]
123
  if content:
124
  messages.append({"role": "user", "content": content})
125
 
126
  try:
127
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
128
- image_inputs, video_inputs = process_vision_info(messages)
129
-
130
- inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
131
- inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
132
-
133
  with torch.no_grad():
134
  generated_ids = model.generate(**inputs, max_new_tokens=512, temperature=0.7, top_p=0.9)
135
-
136
  trimmed = [o[len(i):] for i, o in zip(inputs['input_ids'], generated_ids)]
137
- return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
138
  except Exception as e:
139
  return f"❌ Error: {str(e)}"
140
 
141
- # Chat Wrapper handling the UI logic
142
- def chat_fn(message: Dict[str, Any], history: List[Dict[str, Any]], model_name: str):
143
  text = message.get("text", "")
144
  files = message.get("files", [])
145
-
146
  image = None
147
  if files:
148
  try: image = Image.open(files[0]).convert("RGB")
149
- except Exception as e: print(f"Image load error: {e}")
150
 
151
  response = process_chat(text, image, history, model_name)
152
 
153
- # Append to history precisely as dictionaries (Fixes all Gradio 5+ type errors)
154
  display_text = f"{text}\nπŸ“Ž [Image attached]" if image else text
155
  history.append({"role": "user", "content": display_text})
156
  history.append({"role": "assistant", "content": response})
157
-
158
- # Clears the multimodal textbox on send
159
  return gr.update(value={"text": "", "files": []}), history
160
 
161
-
162
- # ──────────────────────────────────────────────────────────────
163
- # 4. Gradio Interface (Unified UI)
164
- # ──────────────────────────────────────────────────────────────
165
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
166
  gr.Markdown("# πŸͺͺ CSM Smart Document Engine")
167
- gr.Markdown("_Unified architecture with On-Demand Caching & Multi-Turn Chat_")
168
 
169
  with gr.Row(variant="panel"):
170
  model_dropdown = gr.Dropdown(
@@ -172,50 +176,30 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
172
  "Chhagan005/CSM-KIE-Universal",
173
  "Chhagan005/CSM-DocExtract-8N",
174
  "Chhagan005/CSM-DocExtract-4N",
175
- "Chhagan005/CSM-DocExtract-2N"
176
  ],
177
- label="πŸ€– Select Model",
178
- value="Chhagan005/CSM-KIE-Universal",
179
- interactive=True
180
  )
181
  status_bar = gr.Textbox(label="Memory Status", value="Select a model to load into memory", interactive=False)
182
-
183
- # Load model dynamically when dropdown changes
184
  model_dropdown.change(fn=ui_model_change, inputs=[model_dropdown], outputs=[status_bar])
185
 
186
  with gr.Tabs():
187
- # TAB 1: Document Scan
188
  with gr.TabItem("πŸ“„ Document Scanner"):
189
  with gr.Row():
190
  with gr.Column():
191
- doc_img = gr.Image(type="pil", label="Upload ID Card")
192
  scan_btn = gr.Button("πŸ” Extract JSON", variant="primary")
193
  with gr.Column():
194
  json_output = gr.Code(language="json", label="Enterprise Result")
195
  scan_btn.click(fn=run_document_scan, inputs=[doc_img, model_dropdown], outputs=[json_output])
196
 
197
- # TAB 2: Multimodal Chat
198
  with gr.TabItem("πŸ’¬ Intelligent Chat"):
199
- gr.Markdown("**Tips:** Upload an image using the + icon inside the chatbox.")
200
- with gr.Row():
201
- with gr.Column(scale=1):
202
- # Pure Gradio Chatbot (No type=tuples needed since we pass strict dicts now)
203
- chatbot = gr.Chatbot(label="Chat History", height=450, value=[])
204
- # Multimodal box exactly like your reference
205
- chat_msg = gr.MultimodalTextbox(
206
- label="Message",
207
- placeholder="Type a message or click πŸ“Ž to upload an image...",
208
- file_types=["image"],
209
- submit_btn=True
210
- )
211
-
212
- # Submitting the Multimodal Box
213
- chat_msg.submit(
214
- fn=chat_fn,
215
- inputs=[chat_msg, chatbot, model_dropdown],
216
- outputs=[chat_msg, chatbot]
217
  )
 
218
 
219
- # Kickoff initialization
220
  if __name__ == "__main__":
221
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
5
  import gc
6
  from PIL import Image
7
  from transformers import AutoModelForImageTextToText, AutoProcessor
 
8
  import json
9
  import re
10
  from typing import Dict, List, Any, Optional
11
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
 
14
+ # ── Model Cache ──────────────────────────────────────────────
 
 
15
  _model_cache = {}
16
+ MAX_CACHED_MODELS = 2
17
+ QWEN_VL_IMG_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
18
 
19
  def load_model(model_id: str):
 
20
  if model_id in _model_cache:
21
+ print(f"⚑ Cache Hit: {model_id}")
22
  return _model_cache[model_id]
23
 
 
24
  if len(_model_cache) >= MAX_CACHED_MODELS:
25
+ oldest = list(_model_cache.keys())[0]
26
+ print(f"🧹 Unloading: {oldest}")
27
+ del _model_cache[oldest]
28
  gc.collect()
29
 
30
+ print(f"⏳ Loading: {model_id}")
 
31
  try:
32
  processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
33
+ device_map = "auto" if torch.cuda.is_available() else "cpu"
 
 
34
  model = AutoModelForImageTextToText.from_pretrained(
35
+ model_id, device_map=device_map, low_cpu_mem_usage=True, token=HF_TOKEN
 
 
 
36
  )
37
  model.eval()
 
38
  _model_cache[model_id] = (processor, model)
39
+ print(f"βœ… Loaded: {model_id}")
40
  return processor, model
41
  except Exception as e:
 
42
  return None, None
43
 
44
  def ui_model_change(model_id):
45
  processor, model = load_model(model_id)
46
+ if model: return f"βœ… Model Active: {model_id}"
 
47
  return f"❌ Failed to load {model_id}"
48
 
49
+ # ── THE FIX: prepare_inputs (from your reference app.py) ──────
50
+ # Yeh function mixed content (string + list) ko flat format me
51
+ # convert karke processor ko safe tarike se deta hai
52
+ def prepare_inputs(processor, model, messages: List[Dict]) -> Dict:
53
+ pil_images = []
54
+ flat_messages = []
55
+
56
+ for msg in messages:
57
+ role = msg.get("role", "user")
58
+ content = msg.get("content", "")
59
+
60
+ if isinstance(content, list):
61
+ parts = []
62
+ for item in content:
63
+ if not isinstance(item, dict):
64
+ parts.append(str(item))
65
+ continue
66
+ t = item.get("type", "")
67
+ if t == "text":
68
+ parts.append(item.get("text", ""))
69
+ elif t == "image":
70
+ img = item.get("image")
71
+ if img is not None and isinstance(img, Image.Image):
72
+ pil_images.append(img)
73
+ parts.append(QWEN_VL_IMG_TOKEN)
74
+ flat_messages.append({"role": role, "content": "".join(parts)})
75
+ else:
76
+ # History string messages directly add kar do
77
+ flat_messages.append({"role": role, "content": str(content)})
78
+
79
+ text = processor.apply_chat_template(flat_messages, tokenize=False, add_generation_prompt=True)
80
+
81
+ if pil_images and hasattr(processor, "image_processor"):
82
+ inputs = processor(text=[text], images=pil_images, padding=True, return_tensors="pt")
83
+ else:
84
+ inputs = processor(text=[text], padding=True, return_tensors="pt")
85
+
86
+ return {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
87
+
88
+ # ── Enterprise OCR ────────────────────────────────────────────
89
  def extract_tag(tag, text):
90
  match = re.search(f"<(?:{tag})?>(.*?)</(?:{tag})?", text, re.IGNORECASE)
91
  if not match: match = re.search(f"<{tag}>(.*?)</{tag}>", text, re.IGNORECASE)
92
  return match.group(1).strip() if match else "UNKNOWN"
93
 
94
  def build_enterprise_json(raw_text):
 
 
 
 
 
95
  result_json = {
96
  "DocumentMetadata": {"document_type": "Resident Card", "has_mrz": True},
97
  "StructuredData": {
98
+ "civil_number": extract_tag("ID", raw_text),
99
+ "full_name": extract_tag("NAME", raw_text),
100
+ "date_of_birth": extract_tag("DOB", raw_text),
101
+ "nationality": extract_tag("NAT", raw_text)
102
  }
103
  }
104
  return json.dumps(result_json, indent=2, ensure_ascii=False)
105
 
106
  def run_document_scan(front_img, model_name):
107
  if front_img is None: return "Error: Please upload document image."
 
108
  processor, model = load_model(model_name)
109
  if not model: return "Error: Model not loaded."
110
 
111
  prompt = "Extract details inside these XML tags ONLY:\n<ID></ID>\n<NAME></NAME>\n<DOB></DOB>\n<NAT></NAT>"
112
  messages = [{"role": "user", "content": [{"type": "image", "image": front_img}, {"type": "text", "text": prompt}]}]
 
 
 
 
 
 
113
 
114
+ try:
115
+ inputs = prepare_inputs(processor, model, messages)
116
  with torch.no_grad():
117
  generated_ids = model.generate(**inputs, max_new_tokens=150, temperature=0.1)
 
118
  trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
119
  raw_output = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
120
  return build_enterprise_json(raw_output)
121
  except Exception as e:
122
+ return f"Extraction Failed: {str(e)}"
123
 
124
+ # ── Chat ──────────────────────────────────────────────────────
125
+ def process_chat(text: str, image: Optional[Image.Image], history: List[Dict], model_name: str) -> str:
 
 
126
  processor, model = load_model(model_name)
127
  if not model: return "Error: Model not loaded."
128
 
129
+ # Build history messages first
130
+ messages = [{"role": m["role"], "content": m["content"]}
131
+ for m in history if m.get("role") in ("user", "assistant")]
132
+
133
+ # Current message with optional image (as list)
134
  content = []
135
  if image is not None:
136
  content.append({"type": "image", "image": image})
137
+ if text:
138
+ content.append({"type": "text", "text": text})
139
 
 
 
140
  if content:
141
  messages.append({"role": "user", "content": content})
142
 
143
  try:
144
+ # prepare_inputs now handles mixed string/list content safely
145
+ inputs = prepare_inputs(processor, model, messages)
 
 
 
 
146
  with torch.no_grad():
147
  generated_ids = model.generate(**inputs, max_new_tokens=512, temperature=0.7, top_p=0.9)
 
148
  trimmed = [o[len(i):] for i, o in zip(inputs['input_ids'], generated_ids)]
149
+ return processor.batch_decode(trimmed, skip_special_tokens=True)[0]
150
  except Exception as e:
151
  return f"❌ Error: {str(e)}"
152
 
153
+ def chat_fn(message: Dict[str, Any], history: List[Dict], model_name: str):
 
154
  text = message.get("text", "")
155
  files = message.get("files", [])
 
156
  image = None
157
  if files:
158
  try: image = Image.open(files[0]).convert("RGB")
159
+ except Exception as e: print(f"Image error: {e}")
160
 
161
  response = process_chat(text, image, history, model_name)
162
 
 
163
  display_text = f"{text}\nπŸ“Ž [Image attached]" if image else text
164
  history.append({"role": "user", "content": display_text})
165
  history.append({"role": "assistant", "content": response})
 
 
166
  return gr.update(value={"text": "", "files": []}), history
167
 
168
+ # ── Gradio UI ─────────────────────────────────────────────────
 
 
 
169
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
170
  gr.Markdown("# πŸͺͺ CSM Smart Document Engine")
171
+ gr.Markdown("_On-Demand Caching β€’ Document Scanner β€’ Intelligent Multi-Turn Chat_")
172
 
173
  with gr.Row(variant="panel"):
174
  model_dropdown = gr.Dropdown(
 
176
  "Chhagan005/CSM-KIE-Universal",
177
  "Chhagan005/CSM-DocExtract-8N",
178
  "Chhagan005/CSM-DocExtract-4N",
 
179
  ],
180
+ label="πŸ€– Select Model", value="Chhagan005/CSM-KIE-Universal", interactive=True
 
 
181
  )
182
  status_bar = gr.Textbox(label="Memory Status", value="Select a model to load into memory", interactive=False)
183
+
 
184
  model_dropdown.change(fn=ui_model_change, inputs=[model_dropdown], outputs=[status_bar])
185
 
186
  with gr.Tabs():
 
187
  with gr.TabItem("πŸ“„ Document Scanner"):
188
  with gr.Row():
189
  with gr.Column():
190
+ doc_img = gr.Image(type="pil", label="Upload ID Card")
191
  scan_btn = gr.Button("πŸ” Extract JSON", variant="primary")
192
  with gr.Column():
193
  json_output = gr.Code(language="json", label="Enterprise Result")
194
  scan_btn.click(fn=run_document_scan, inputs=[doc_img, model_dropdown], outputs=[json_output])
195
 
 
196
  with gr.TabItem("πŸ’¬ Intelligent Chat"):
197
+ chatbot = gr.Chatbot(label="Chat History", height=450, value=[])
198
+ chat_msg = gr.MultimodalTextbox(
199
+ label="Message", placeholder="Type a message or click πŸ“Ž to attach an image...",
200
+ file_types=["image"], submit_btn=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  )
202
+ chat_msg.submit(fn=chat_fn, inputs=[chat_msg, chatbot, model_dropdown], outputs=[chat_msg, chatbot])
203
 
 
204
  if __name__ == "__main__":
205
  demo.launch(server_name="0.0.0.0", server_port=7860)