akashraut commited on
Commit
cdacb08
·
verified ·
1 Parent(s): d4bebd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -47
app.py CHANGED
@@ -2,77 +2,75 @@ import gradio as gr
2
  import torch
3
  import json
4
  from PIL import Image
5
- from transformers import AutoProcessor, AutoModel
6
 
7
- MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
 
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  # Processor
12
- processor = AutoProcessor.from_pretrained(
13
- MODEL_ID,
14
- trust_remote_code=True
15
- )
16
 
17
- # Model (REMOTE CODE LOAD β€” critical)
18
- model = AutoModel.from_pretrained(
19
  MODEL_ID,
20
  trust_remote_code=True,
21
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
 
 
22
  device_map="auto"
23
  )
24
 
25
  model.eval()
26
 
27
  def extract_document(image: Image.Image):
28
- prompt = """
29
- You are a universal document understanding AI.
30
- Return ONLY valid JSON.
31
-
32
- Extract:
33
- - document_type
34
- - key-value fields
35
- - tables with rows and columns
36
-
37
- Be document-agnostic.
38
- Do not hallucinate.
39
- """
40
 
 
41
  inputs = processor(
42
- images=image,
43
- text=prompt,
44
- return_tensors="pt"
45
- ).to(model.device)
 
46
 
 
47
  with torch.no_grad():
48
- outputs = model.generate(
49
- **inputs,
50
- max_new_tokens=2048,
51
- temperature=0.0
52
- )
53
-
54
- text = processor.decode(outputs[0], skip_special_tokens=True)
 
 
 
55
 
56
  try:
 
57
  start = text.find("{")
58
  end = text.rfind("}") + 1
59
  return json.loads(text[start:end])
60
  except Exception:
61
- return {
62
- "error": "Model output could not be parsed",
63
- "raw_output": text
64
- }
65
 
66
  with gr.Blocks() as demo:
67
  gr.Markdown("# πŸ“„ DocAI β€” Universal Document Intelligence")
68
-
69
- image = gr.Image(type="pil", label="Upload document")
70
- output = gr.JSON(label="Extracted JSON")
71
-
72
- gr.Button("Extract").click(
73
- extract_document,
74
- inputs=image,
75
- outputs=output
76
- )
77
-
78
- demo.launch()
 
 
2
import torch
import json
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# NOTE: on a free CPU space, prefer "Qwen/Qwen2-VL-2B-Instruct" — the 7B
# checkpoint can exhaust memory and crash the space.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

# Pick the compute device once at import time.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Processor bundles the tokenizer and the image preprocessor for this checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# bfloat16 suits Qwen on GPU and uses half the memory of float32; stay in
# float32 on CPU.
_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# device_map="auto" lets accelerate place the weights; low_cpu_mem_usage
# streams the checkpoint in to keep peak RAM down during loading.
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=_dtype,
    low_cpu_mem_usage=True,
    device_map="auto",
)

model.eval()
27
 
28
def extract_document(image: Image.Image):
    """Run the vision-language model on a document image and return a dict.

    Parameters
    ----------
    image : PIL.Image.Image | None
        The uploaded document page. ``None`` (nothing uploaded) returns an
        error dict instead of raising.

    Returns
    -------
    dict
        The parsed JSON payload on success; otherwise ``{"raw_output": ...}``
        with the model's verbatim text, or ``{"error": ...}`` when no image
        was provided.
    """
    if image is None:
        return {"error": "No image uploaded"}

    # Qwen2.5-VL chat template written out literally: a system turn, then a
    # user turn containing the image placeholder plus the instruction.
    prompt = "<|im_start|>system\nYou are a universal document understanding AI. Return ONLY valid JSON.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Extract document_type, key-value fields, and tables from this document.<|im_end|>\n<|im_start|>assistant\n"

    # Tokenize the text and preprocess the image together.
    inputs = processor(
        text=[prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    # FIX: with device_map="auto" the weights may not live on the module-level
    # `device`; model.device is where accelerate placed the input embeddings.
    inputs = inputs.to(model.device)

    # FIX: force greedy decoding so extraction is deterministic even if the
    # checkpoint ships a sampling-enabled generation config.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)

    # Trim the prompt tokens so we decode only the newly generated answer.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # Extract the outermost {...} span — models often wrap JSON in prose or
    # markdown fences. Guard explicitly when no brace pair exists.
    start = text.find("{")
    end = text.rfind("}") + 1
    if start == -1 or end <= start:
        return {"raw_output": text}
    try:
        return json.loads(text[start:end])
    except Exception:
        # Fall back to the raw text so the user can see what the model said.
        return {"raw_output": text}
 
 
 
62
 
63
# Gradio UI: image upload on the left, extracted JSON on the right.
with gr.Blocks() as demo:
    # FIX: repaired mojibake in user-facing strings — "πŸ“„" and "β€”" were
    # UTF-8 "📄" and "—" decoded under a legacy codepage.
    gr.Markdown("# 📄 DocAI — Universal Document Intelligence")
    gr.Markdown("Using Qwen2.5-VL for structured document extraction.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload document")
            btn = gr.Button("Extract Data", variant="primary")
        with gr.Column():
            output_json = gr.JSON(label="Extracted JSON")

    # Wire the button to the extraction function defined above.
    btn.click(extract_document, inputs=image_input, outputs=output_json)

demo.launch()