Chhagan005 committed on
Commit
0e9f127
·
verified ·
1 Parent(s): c3f5977

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import transforms
6
+ from huggingface_hub import hf_hub_download
7
+ import json
8
+ import string
9
+
10
# --- Recreate architecture for inference ---
# The module definitions that follow must match the training notebook's
# architecture.
# Number of rows in the decoder's learned positional table, i.e. the longest
# target sequence the decoder can embed.
MAX_SEQ_LEN: int = 1500
13
+
14
class CSMTokenizer:
    """Character-level tokenizer used to turn generated ids back into text.

    The vocabulary is every character of ``string.printable`` followed by the
    code points U+0600..U+06FF (the Unicode Arabic block), in that order.
    Ids 0-3 are reserved for the PAD, SOS, EOS and UNK special tokens, so the
    first real character has id 4.
    """

    def __init__(self) -> None:
        # The character order defines the ids, so it must not be changed.
        self.chars = list(string.printable) + [
            chr(code_point) for code_point in range(0x0600, 0x06FF + 1)
        ]
        self.PAD, self.SOS, self.EOS, self.UNK = 0, 1, 2, 3
        self.vocab = {char: idx for idx, char in enumerate(self.chars, start=4)}
        self.inverse_vocab = dict(enumerate(self.chars, start=4))
        self.vocab_size = len(self.vocab) + 4

    def decode(self, tokens) -> str:
        """Return the text for an iterable of token ids.

        PAD, SOS and EOS ids are dropped. Any id without a character mapping
        (including UNK) contributes an empty string.
        """
        # Build the skip-set once per call instead of once per token.
        skipped = {self.PAD, self.SOS, self.EOS}
        lookup = self.inverse_vocab
        return "".join(lookup.get(tok, "") for tok in tokens if tok not in skipped)
24
+
25
class CSMVisionEncoder(nn.Module):
    """CNN encoder that turns an image batch into a sequence of patch features.

    Four 3x3, stride-2 convolutions (each followed by ReLU and BatchNorm2d)
    map the 3-channel input to ``embed_dim`` channels. The resulting spatial
    grid is flattened into a sequence and a learned positional embedding is
    added.

    Args:
        embed_dim: Output channels of the last convolution, which is also the
            feature size of every sequence element.
        max_patches: Number of rows in the learned positional table. The
            default of 256 keeps the parameter shape of the original
            definition.
    """

    def __init__(self, embed_dim=256, max_patches=256):
        super().__init__()
        # Keep the Conv -> ReLU -> BatchNorm order: the positions inside the
        # Sequential determine the parameter names in the state_dict.
        widths = [3, 32, 64, 128, embed_dim]
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(out_ch),
            ])
        self.cnn = nn.Sequential(*layers)
        self.pos_embed = nn.Parameter(torch.randn(1, max_patches, embed_dim))

    def forward(self, x):
        """Return positional-encoded patch features for the image batch ``x``.

        Raises:
            ValueError: If the flattened feature grid has more positions than
                the positional table can cover.
        """
        # Flatten the spatial dims into one axis, then move channels last.
        features = self.cnn(x).flatten(2).permute(0, 2, 1)
        num_patches = features.size(1)
        table_len = self.pos_embed.size(1)
        if num_patches > table_len:
            # Fail with a clear message instead of an opaque broadcast error.
            raise ValueError(
                f"Input yields {num_patches} patches but the positional "
                f"embedding only covers {table_len}; use a smaller image."
            )
        return features + self.pos_embed[:, :num_patches, :]
39
+
40
class CSMJSONDecoder(nn.Module):
    """Transformer decoder that maps target token ids and encoder memory to logits.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows and output
            logits).
        embed_dim: Model width shared by the embedding, transformer and
            output projection.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
    """

    def __init__(self, vocab_size, embed_dim=256, num_heads=8, num_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional table; its length caps the target sequence length.
        self.pos_encoder = nn.Parameter(torch.randn(1, MAX_SEQ_LEN, embed_dim))
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, tgt, memory):
        """Return vocabulary logits for every position of ``tgt``.

        Args:
            tgt: Integer token ids, indexed as (batch, sequence).
            memory: Encoder output attended to by the decoder layers.

        Raises:
            ValueError: If ``tgt`` is longer than the positional table.
        """
        seq_len = tgt.size(1)
        max_len = self.pos_encoder.size(1)
        if seq_len > max_len:
            # Fail with a clear message instead of an opaque broadcast error.
            raise ValueError(
                f"Target length {seq_len} exceeds the maximum supported "
                f"sequence length {max_len}."
            )
        tgt_embed = self.embedding(tgt) + self.pos_encoder[:, :seq_len, :]
        # NOTE(review): no tgt_mask is passed, so target positions are not
        # causally masked here -- confirm this matches the training setup.
        return self.fc_out(self.transformer(tgt_embed, memory))
52
+
53
class CSM_KIE_Universal(nn.Module):
    """Container pairing the vision encoder with the JSON token decoder.

    The sub-modules are exposed as ``encoder`` and ``decoder``; these
    attribute names are part of the checkpoint's parameter names.

    Args:
        vocab_size: Vocabulary size forwarded to the decoder.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.encoder = CSMVisionEncoder()
        self.decoder = CSMJSONDecoder(vocab_size)

    def forward(self, images, tgt):
        """Encode ``images`` and return decoder logits for the ids in ``tgt``.

        Provided so the module can be called directly; callers may still use
        ``encoder`` and ``decoder`` separately (e.g. to reuse the encoder
        output across decoding steps).
        """
        return self.decoder(tgt, self.encoder(images))
58
+
59
# --- Initialization ---
tokenizer = CSMTokenizer()
device = torch.device("cpu")

# Load the quantized model: download the checkpoint, rebuild the
# architecture, quantize it, then load the weights into it.
print("Downloading trained model...")
model_path = hf_hub_download(
    repo_id="Chhagan005/CSM-KIE-Universal",
    filename="csm_kie_model.pth",
)
model = CSM_KIE_Universal(tokenizer.vocab_size)
# Quantization happens before load_state_dict, so the checkpoint is presumably
# saved from a model quantized with this same spec -- keep the two in sync.
model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear, nn.Conv2d},
    dtype=torch.qint8,
)
# NOTE(review): torch.load unpickles a downloaded file; consider
# weights_only=True once confirmed to work with this quantized checkpoint.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Preprocessing applied to every uploaded image before it reaches the encoder.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
76
+
77
+ # --- Inference Function ---
78
def process_id_card(front_img, back_img):
    """Run the model on an ID card image and return the extracted data as text.

    Args:
        front_img: Image of the front side; must provide ``convert('RGB')``.
            ``None`` means nothing was uploaded.
        back_img: Accepted to keep the UI callback signature; it is not used
            by the current inference code.

    Returns:
        A JSON error message when ``front_img`` is ``None``. Otherwise the
        generated text, pretty-printed if it parses as JSON and returned
        unchanged if it does not.
    """
    if front_img is None:
        return '{"error": "Please upload at least the Front side of the ID card."}'

    # Preprocess and add a batch dimension.
    img_tensor = image_transform(front_img.convert("RGB")).unsqueeze(0)

    # Greedy autoregressive generation, starting from the SOS token.
    generated_tokens = [tokenizer.SOS]
    with torch.no_grad():
        # Encode inside no_grad as well so no autograd graph is built for it.
        memory = model.encoder(img_tensor)
        for _ in range(1000):  # Max length
            tgt_tensor = torch.tensor([generated_tokens], dtype=torch.long)
            logits = model.decoder(tgt_tensor, memory)
            next_token = logits[0, -1, :].argmax().item()
            generated_tokens.append(next_token)
            if next_token == tokenizer.EOS:
                break

    json_string = tokenizer.decode(generated_tokens)

    # Pretty-print valid JSON; fall back to the raw text when the model output
    # is not (yet) well-formed JSON.
    try:
        parsed_json = json.loads(json_string)
    except (ValueError, RecursionError):
        return json_string
    return json.dumps(parsed_json, indent=2, ensure_ascii=False)
106
+
107
# --- Gradio UI ---
# Two-column layout: image inputs and the scan button on the left, the
# model's JSON output on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🪪 CSM-KIE Universal ID Scanner")
    gr.Markdown(
        "Upload Front and Back sides of any International ID card (Middle East, Africa, etc.) to extract multilingual structured JSON data using the proprietary CSM-DocVL model."
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            front = gr.Image(type="pil", label="Front Side (Required)")
            back = gr.Image(type="pil", label="Back Side / MRZ (Optional)")
            scan_btn = gr.Button("🔍 Scan & Extract JSON", variant="primary")

        # Right column: result.
        with gr.Column():
            output_json = gr.Code(language="json", label="Structured JSON Output")

    # Clicking the button passes both images to the inference function.
    scan_btn.click(
        process_id_card,
        inputs=[front, back],
        outputs=output_json,
    )

demo.launch()