redhairedshanks1 commited on
Commit
a13ddef
·
verified ·
1 Parent(s): 1c6aa49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -9
app.py CHANGED
@@ -13,20 +13,50 @@
13
  # demo.launch()
14
 
15
  from PIL import Image
16
- from huggingface_hub import snapshot_download
17
- from transformers import AutoProcessor, AutoModelForCausalLM
18
  import gradio as gr
19
  import torch
 
 
20
 
 
21
  MODEL_ID = "rednote-hilab/dots.ocr"
22
- local = snapshot_download(MODEL_ID)
23
- model = AutoModelForCausalLM.from_pretrained(local, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16)
24
- processor = AutoProcessor.from_pretrained(local, trust_remote_code=True)
25
 
26
- def parse_document(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  inputs = processor(images=[image], return_tensors="pt").to(model.device)
28
- output = model.generate(**inputs, do_sample=False, max_new_tokens=1024)
 
 
 
 
 
29
  return processor.batch_decode(output, skip_special_tokens=True)[0]
30
 
31
- demo = gr.Interface(parse_document, inputs=gr.Image(type="pil"), outputs="text")
32
- if __name__ == "__main__": demo.launch()
 
 
 
 
 
 
 
 
 
 
 
13
  # demo.launch()
14
 
15
  from PIL import Image
 
 
16
  import gradio as gr
17
  import torch
18
+ from huggingface_hub import snapshot_download
19
+ from transformers import AutoProcessor, AutoModelForCausalLM
20
 
21
+ # Model ID
22
  MODEL_ID = "rednote-hilab/dots.ocr"
 
 
 
23
 
24
+ # Download snapshot locally
25
+ local_model_path = snapshot_download(MODEL_ID)
26
+
27
+ # Load model & processor
28
+ model = AutoModelForCausalLM.from_pretrained(
29
+ local_model_path,
30
+ trust_remote_code=True,
31
+ device_map="auto",
32
+ torch_dtype=torch.bfloat16
33
+ )
34
+
35
+ processor = AutoProcessor.from_pretrained(
36
+ local_model_path,
37
+ trust_remote_code=True
38
+ )
39
+
40
+ # OCR parsing function
41
+ def parse_document(image: Image.Image):
42
  inputs = processor(images=[image], return_tensors="pt").to(model.device)
43
+ with torch.no_grad():
44
+ output = model.generate(
45
+ **inputs,
46
+ do_sample=False,
47
+ max_new_tokens=1024
48
+ )
49
  return processor.batch_decode(output, skip_special_tokens=True)[0]
50
 
51
+ # Gradio UI
52
+ demo = gr.Interface(
53
+ fn=parse_document,
54
+ inputs=gr.Image(type="pil", label="Upload Document"),
55
+ outputs=gr.Textbox(label="Extracted Text"),
56
+ title="Dots OCR Demo",
57
+ description="Upload an image or scanned document to extract text using rednote-hilab/dots.ocr"
58
+ )
59
+
60
+ if __name__ == "__main__":
61
+ demo.launch()
62
+