Chhagan005 committed on
Commit
9e25965
·
verified ·
1 Parent(s): 6442ebb

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +149 -72
app.py CHANGED
@@ -3,48 +3,68 @@ import os
3
  import gradio as gr
4
  import torch
5
  import gc
 
6
  from transformers import AutoModelForImageTextToText, AutoProcessor
7
  from qwen_vl_utils import process_vision_info
8
  import json
9
  import re
10
-
11
- # Global State
12
- current_model_id = None
13
- model = None
14
- processor = None
15
 
16
  HF_TOKEN = os.environ.get("HF_TOKEN")
17
 
18
- def load_selected_model(repo_id):
19
- global model, processor, current_model_id
20
- if repo_id == current_model_id and model is not None:
21
- return f"βœ… Model {repo_id} is already active."
22
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  try:
24
- if model is not None:
25
- del model
26
- del processor
27
- gc.collect()
28
-
29
- print(f"Loading {repo_id}...")
30
- processor = AutoProcessor.from_pretrained(repo_id, token=HF_TOKEN)
31
 
32
  model = AutoModelForImageTextToText.from_pretrained(
33
- repo_id,
34
- device_map="cpu",
35
  low_cpu_mem_usage=True,
36
  token=HF_TOKEN
37
  )
38
  model.eval()
39
- current_model_id = repo_id
40
- return f"πŸš€ Successfully Loaded: {repo_id}"
 
 
41
  except Exception as e:
42
- return f"❌ Error loading {repo_id}: {str(e)}"
 
43
 
 
 
 
 
 
 
 
 
 
44
  def extract_tag(tag, text):
45
  match = re.search(f"<(?:{tag})?>(.*?)</(?:{tag})?", text, re.IGNORECASE)
46
- if not match:
47
- match = re.search(f"<{tag}>(.*?)</{tag}>", text, re.IGNORECASE)
48
  return match.group(1).strip() if match else "UNKNOWN"
49
 
50
  def build_enterprise_json(raw_text):
@@ -61,36 +81,91 @@ def build_enterprise_json(raw_text):
61
  }
62
  return json.dumps(result_json, indent=2, ensure_ascii=False)
63
 
64
- def run_qwen(image, prompt_text, max_tokens=150):
65
- if model is None:
66
- return "Error: Please load a model from the dropdown first."
67
- if image is None:
68
- return "Error: Image required."
 
 
 
69
 
70
- messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt_text}]}]
71
  try:
72
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
73
  image_inputs, video_inputs = process_vision_info(messages)
74
-
75
  inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
76
- inputs = {k: v.to("cpu") for k, v in inputs.items() if isinstance(v, torch.Tensor)}
77
 
78
  with torch.no_grad():
79
- generated_ids = model.generate(**inputs, max_new_tokens=max_tokens, temperature=0.1)
80
 
81
- generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
82
- return processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
83
  except Exception as e:
84
  return f"Extraction Failed: {str(e)}"
85
 
86
- def ocr_extraction(front_img):
87
- prompt = "Extract details inside these XML tags ONLY:\n<ID></ID>\n<NAME></NAME>\n<DOB></DOB>\n<NAT></NAT>"
88
- raw_output = run_qwen(front_img, prompt, max_tokens=150)
89
- return build_enterprise_json(raw_output)
 
 
 
 
 
 
 
 
90
 
91
- with gr.Blocks() as demo:
92
- gr.Markdown("# πŸͺͺ CSM Universal Model Testing Playground")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  with gr.Row(variant="panel"):
95
  model_dropdown = gr.Dropdown(
96
  choices=[
@@ -99,46 +174,48 @@ with gr.Blocks() as demo:
99
  "Chhagan005/CSM-DocExtract-4N",
100
  "Chhagan005/CSM-DocExtract-2N"
101
  ],
102
- label="1. Select Model Version",
103
- value="Chhagan005/CSM-KIE-Universal"
 
104
  )
105
- load_btn = gr.Button("πŸ”„ Load Model")
106
- load_status = gr.Textbox(label="Status", interactive=False)
107
-
108
- load_btn.click(load_selected_model, inputs=[model_dropdown], outputs=[load_status])
109
 
110
  with gr.Tabs():
111
- with gr.Tab("Enterprise OCR Check"):
 
112
  with gr.Row():
113
  with gr.Column():
114
- img_input = gr.Image(type="pil", label="Upload Document")
115
- extract_btn = gr.Button("πŸ” Test Extraction", variant="primary")
116
  with gr.Column():
117
- json_output = gr.Code(language="json", label="JSON Output")
118
- extract_btn.click(ocr_extraction, inputs=[img_input], outputs=[json_output])
119
-
120
- with gr.Tab("Document Chat Check"):
 
 
121
  with gr.Row():
122
  with gr.Column(scale=1):
123
- chat_img_input = gr.Image(type="pil", label="Document Attachment")
124
- with gr.Column(scale=2):
125
- # ✨ THE FIX: We explicitly define type="tuples" so Gradio accepts our format without crashing
126
- chatbot = gr.Chatbot(label="Chat Interface", height=400, type="tuples")
127
- with gr.Row():
128
- chat_input = gr.Textbox(placeholder="Ask anything...", show_label=False)
129
- send_btn = gr.Button("Send")
 
 
130
 
131
- def chat_wrapper(image, user_message, chat_history):
132
- # Ensure chat_history is a list
133
- if chat_history is None:
134
- chat_history = []
135
-
136
- ai_response = run_qwen(image, user_message, max_tokens=200)
137
- # Appending as a tuple (user_message, ai_response) which matches type="tuples"
138
- chat_history.append((user_message, ai_response))
139
- return "", chat_history
140
-
141
- send_btn.click(chat_wrapper, inputs=[chat_img_input, chat_input, chatbot], outputs=[chat_input, chatbot])
142
 
 
143
  if __name__ == "__main__":
144
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
3
  import gradio as gr
4
  import torch
5
  import gc
6
+ from PIL import Image
7
  from transformers import AutoModelForImageTextToText, AutoProcessor
8
  from qwen_vl_utils import process_vision_info
9
  import json
10
  import re
11
+ from typing import Dict, List, Any, Optional
 
 
 
 
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
 
15
+ # ──────────────────────────────────────────────────────────────
16
+ # 1. Smart Memory Cache (From your reference, heavily optimized)
17
+ # ──────────────────────────────────────────────────────────────
18
+ _model_cache = {}
19
+ MAX_CACHED_MODELS = 2 # Limits RAM usage on free HF Space CPU
20
+
21
+ def load_model(model_id: str):
22
+ # 1. Agar cache me hai, wahi se return karo (0 loading time)
23
+ if model_id in _model_cache:
24
+ print(f"⚑ Fast Load: {model_id} already in cache!")
25
+ return _model_cache[model_id]
26
+
27
+ # 2. RAM check (Agar memory full hai, toh sabse purana model nikal do)
28
+ if len(_model_cache) >= MAX_CACHED_MODELS:
29
+ oldest_model = list(_model_cache.keys())[0]
30
+ print(f"🧹 Memory Full! Unloading old model: {oldest_model}")
31
+ del _model_cache[oldest_model]
32
+ gc.collect()
33
+
34
+ # 3. Pehli baar model load karo
35
+ print(f"⏳ Loading model into memory: {model_id}")
36
  try:
37
+ processor = AutoProcessor.from_pretrained(model_id, token=HF_TOKEN)
38
+ # Check for GPU (from reference)
39
+ device_type = "auto" if torch.cuda.is_available() else "cpu"
 
 
 
 
40
 
41
  model = AutoModelForImageTextToText.from_pretrained(
42
+ model_id,
43
+ device_map=device_type,
44
  low_cpu_mem_usage=True,
45
  token=HF_TOKEN
46
  )
47
  model.eval()
48
+
49
+ _model_cache[model_id] = (processor, model)
50
+ print(f"βœ… {model_id} loaded successfully!")
51
+ return processor, model
52
  except Exception as e:
53
+ print(f"❌ Error loading {model_id}: {str(e)}")
54
+ return None, None
55
 
56
def ui_model_change(model_id):
    """Dropdown callback: load *model_id* into the cache and report status."""
    processor, model = load_model(model_id)
    active_msg = f"✅ Model Active: {model_id} (Cached in Memory)"
    failed_msg = f"❌ Failed to load {model_id}"
    return active_msg if model else failed_msg
61
+
62
+ # ──────────────────────────────────────────────────────────────
63
+ # 2. Enterprise OCR JSON Parsing (Our logic)
64
+ # ──────────────────────────────────────────────────────────────
65
  def extract_tag(tag, text):
66
  match = re.search(f"<(?:{tag})?>(.*?)</(?:{tag})?", text, re.IGNORECASE)
67
+ if not match: match = re.search(f"<{tag}>(.*?)</{tag}>", text, re.IGNORECASE)
 
68
  return match.group(1).strip() if match else "UNKNOWN"
69
 
70
  def build_enterprise_json(raw_text):
 
81
  }
82
  return json.dumps(result_json, indent=2, ensure_ascii=False)
83
 
84
def run_document_scan(front_img, model_name):
    """Run key-information extraction on an ID document image.

    front_img: PIL image from the Gradio component (None when nothing uploaded).
    model_name: Hub repo id selected in the dropdown; resolved via load_model's cache.
    Returns a formatted JSON string on success, or a plain error string on failure.
    """
    if front_img is None: return "Error: Please upload document image."

    processor, model = load_model(model_name)
    if not model: return "Error: Model not loaded."

    # Constrained prompt: the model is told to answer inside fixed XML tags
    # so extract_tag()/build_enterprise_json() can parse the raw output.
    prompt = "Extract details inside these XML tags ONLY:\n<ID></ID>\n<NAME></NAME>\n<DOB></DOB>\n<NAT></NAT>"
    messages = [{"role": "user", "content": [{"type": "image", "image": front_img}, {"type": "text", "text": prompt}]}]

    try:
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
        # Move tensors to the model's device; non-tensor entries are dropped —
        # assumes the processor output here is tensors only (TODO confirm).
        inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

        with torch.no_grad():
            # NOTE(review): temperature has no effect unless do_sample=True is
            # also passed (transformers defaults to greedy) — confirm intent.
            generated_ids = model.generate(**inputs, max_new_tokens=150, temperature=0.1)

        # Strip the echoed prompt tokens so only newly generated text is decoded.
        trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)]
        raw_output = processor.batch_decode(trimmed, skip_special_tokens=True)[0]
        return build_enterprise_json(raw_output)
    except Exception as e:
        return f"Extraction Failed: {str(e)}"
107
 
108
+ # ──────────────────────────────────────────────────────────────
109
+ # 3. Chat Inference (Reference Architecture Logic)
110
+ # ──────────────────────────────────────────────────────────────
111
+ def process_chat(message: str, image: Optional[Image.Image], history: List[Dict[str, Any]], model_name: str) -> str:
112
+ processor, model = load_model(model_name)
113
+ if not model: return "Error: Model not loaded."
114
+
115
+ content = []
116
+ if image is not None:
117
+ content.append({"type": "image", "image": image})
118
+ if message:
119
+ content.append({"type": "text", "text": message})
120
 
121
+ # Prepare pure history dictionary
122
+ messages = [{"role": m["role"], "content": m["content"]} for m in history if m.get("role") in ("user", "assistant")]
123
+ if content:
124
+ messages.append({"role": "user", "content": content})
125
+
126
+ try:
127
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
128
+ image_inputs, video_inputs = process_vision_info(messages)
129
+
130
+ inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
131
+ inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
132
+
133
+ with torch.no_grad():
134
+ generated_ids = model.generate(**inputs, max_new_tokens=512, temperature=0.7, top_p=0.9)
135
+
136
+ trimmed = [o[len(i):] for i, o in zip(inputs['input_ids'], generated_ids)]
137
+ return processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
138
+ except Exception as e:
139
+ return f"❌ Error: {str(e)}"
140
+
141
# Chat wrapper handling the UI logic
def chat_fn(message: Dict[str, Any], history: List[Dict[str, Any]], model_name: str):
    """Submit handler for the gr.MultimodalTextbox.

    message: {"text": str, "files": [paths]} as emitted by the component.
    Appends the new user/assistant turns to *history* as role/content dicts
    and clears the input box. Returns (textbox update, updated history).
    """
    text = message.get("text", "")
    files = message.get("files", [])

    # Load the first attached file as an RGB image; unreadable files are
    # logged and skipped rather than crashing the chat turn.
    image = None
    if files:
        try:
            image = Image.open(files[0]).convert("RGB")
        except Exception as e:
            print(f"Image load error: {e}")

    # Guard: Gradio can hand over None before the first turn.
    if history is None:
        history = []

    response = process_chat(text, image, history, model_name)

    # Append as role/content dicts (the gr.Chatbot "messages" format).
    display_text = f"{text}\n📎 [Image attached]" if image else text
    history.append({"role": "user", "content": display_text})
    history.append({"role": "assistant", "content": response})

    # Clears the multimodal textbox on send.
    return gr.update(value={"text": "", "files": []}), history
160
+
161
+
162
+ # ──────────────────────────────────────────────────────────────
163
+ # 4. Gradio Interface (Unified UI)
164
+ # ──────────────────────────────────────────────────────────────
165
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
166
+ gr.Markdown("# πŸͺͺ CSM Smart Document Engine")
167
+ gr.Markdown("_Unified architecture with On-Demand Caching & Multi-Turn Chat_")
168
+
169
  with gr.Row(variant="panel"):
170
  model_dropdown = gr.Dropdown(
171
  choices=[
 
174
  "Chhagan005/CSM-DocExtract-4N",
175
  "Chhagan005/CSM-DocExtract-2N"
176
  ],
177
+ label="πŸ€– Select Model",
178
+ value="Chhagan005/CSM-KIE-Universal",
179
+ interactive=True
180
  )
181
+ status_bar = gr.Textbox(label="Memory Status", value="Select a model to load into memory", interactive=False)
182
+
183
+ # Load model dynamically when dropdown changes
184
+ model_dropdown.change(fn=ui_model_change, inputs=[model_dropdown], outputs=[status_bar])
185
 
186
  with gr.Tabs():
187
+ # TAB 1: Document Scan
188
+ with gr.TabItem("πŸ“„ Document Scanner"):
189
  with gr.Row():
190
  with gr.Column():
191
+ doc_img = gr.Image(type="pil", label="Upload ID Card")
192
+ scan_btn = gr.Button("πŸ” Extract JSON", variant="primary")
193
  with gr.Column():
194
+ json_output = gr.Code(language="json", label="Enterprise Result")
195
+ scan_btn.click(fn=run_document_scan, inputs=[doc_img, model_dropdown], outputs=[json_output])
196
+
197
+ # TAB 2: Multimodal Chat
198
+ with gr.TabItem("πŸ’¬ Intelligent Chat"):
199
+ gr.Markdown("**Tips:** Upload an image using the + icon inside the chatbox.")
200
  with gr.Row():
201
  with gr.Column(scale=1):
202
+ # Pure Gradio Chatbot (No type=tuples needed since we pass strict dicts now)
203
+ chatbot = gr.Chatbot(label="Chat History", height=450, value=[])
204
+ # Multimodal box exactly like your reference
205
+ chat_msg = gr.MultimodalTextbox(
206
+ label="Message",
207
+ placeholder="Type a message or click πŸ“Ž to upload an image...",
208
+ file_types=["image"],
209
+ submit_btn=True
210
+ )
211
 
212
+ # Submitting the Multimodal Box
213
+ chat_msg.submit(
214
+ fn=chat_fn,
215
+ inputs=[chat_msg, chatbot, model_dropdown],
216
+ outputs=[chat_msg, chatbot]
217
+ )
 
 
 
 
 
# Entry point: bind to all interfaces on port 7860 (HF Spaces convention).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)