Nick-2x committed on
Commit
6785323
·
verified ·
1 Parent(s): f398029

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -33
app.py CHANGED
@@ -1,42 +1,97 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- # from transformers import AutoProcessor, AutoModelForVision2Seq,AutoModel
3
- from transformers import AutoProcessor,AutoModel
4
- from PIL import Image
5
- import torch
6
- import io
7
 
8
- app = FastAPI()
9
 
10
- MODEL_ID = "zai-org/GLM-OCR"
11
 
12
- print("Loading GLM-OCR model...")
13
 
14
- # processor = AutoProcessor.from_pretrained(MODEL_ID)
15
- # model = AutoModelForVision2Seq.from_pretrained(
16
- # MODEL_ID,
17
- # torch_dtype=torch.float32
18
- # )
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # processor = AutoProcessor.from_pretrained(
21
  # MODEL_ID,
22
  # trust_remote_code=True
23
  # )
24
 
25
- # model = AutoModelForVision2Seq.from_pretrained(
26
  # MODEL_ID,
27
- # trust_remote_code=True,
28
- # torch_dtype=torch.float32
29
  # )
30
 
31
- processor = AutoProcessor.from_pretrained(
32
- MODEL_ID,
33
- trust_remote_code=True
34
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- model = AutoModel.from_pretrained(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  MODEL_ID,
38
- trust_remote_code=True
39
- )
 
40
 
41
  @app.get("/")
42
  async def root():
@@ -45,24 +100,48 @@ async def root():
45
  @app.post("/ocr")
46
  async def extract_text(file: UploadFile = File(...)):
47
  try:
 
48
  contents = await file.read()
49
  image = Image.open(io.BytesIO(contents)).convert("RGB")
50
 
51
- # inputs = processor(images=image, return_tensors="pt")
52
- inputs = processor(
53
- text="Extract all text from the document",
54
- images=image,
55
- return_tensors="pt"
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
 
58
  with torch.no_grad():
59
- outputs = model.generate(**inputs, max_new_tokens=1024)
 
 
 
 
60
 
61
- text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
 
 
 
62
 
63
  return {
64
  "success": True,
65
- "text": text
66
  }
67
 
68
  except Exception as e:
 
1
+ # from fastapi import FastAPI, UploadFile, File
2
+ # # from transformers import AutoProcessor, AutoModelForVision2Seq,AutoModel
3
+ # from transformers import AutoProcessor,AutoModel
4
+ # from PIL import Image
5
+ # import torch
6
+ # import io
7
 
8
+ # app = FastAPI()
9
 
10
+ # MODEL_ID = "zai-org/GLM-OCR"
11
 
12
+ # print("Loading GLM-OCR model...")
13
 
14
+ # # processor = AutoProcessor.from_pretrained(MODEL_ID)
15
+ # # model = AutoModelForVision2Seq.from_pretrained(
16
+ # # MODEL_ID,
17
+ # # torch_dtype=torch.float32
18
+ # # )
19
+
20
+ # # processor = AutoProcessor.from_pretrained(
21
+ # # MODEL_ID,
22
+ # # trust_remote_code=True
23
+ # # )
24
+
25
+ # # model = AutoModelForVision2Seq.from_pretrained(
26
+ # # MODEL_ID,
27
+ # # trust_remote_code=True,
28
+ # # torch_dtype=torch.float32
29
+ # # )
30
 
31
  # processor = AutoProcessor.from_pretrained(
32
  # MODEL_ID,
33
  # trust_remote_code=True
34
  # )
35
 
36
+ # model = AutoModel.from_pretrained(
37
  # MODEL_ID,
38
+ # trust_remote_code=True
 
39
  # )
40
 
41
+ # @app.get("/")
42
+ # async def root():
43
+ # return {"status": "GLM-OCR API is running"}
44
+
45
+ # @app.post("/ocr")
46
+ # async def extract_text(file: UploadFile = File(...)):
47
+ # try:
48
+ # contents = await file.read()
49
+ # image = Image.open(io.BytesIO(contents)).convert("RGB")
50
+
51
+ # # inputs = processor(images=image, return_tensors="pt")
52
+ # inputs = processor(
53
+ # text="Extract all text from the document",
54
+ # images=image,
55
+ # return_tensors="pt"
56
+ # )
57
+
58
+ # with torch.no_grad():
59
+ # outputs = model.generate(**inputs, max_new_tokens=1024)
60
 
61
+ # text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
62
+
63
+ # return {
64
+ # "success": True,
65
+ # "text": text
66
+ # }
67
+
68
+ # except Exception as e:
69
+ # return {
70
+ # "success": False,
71
+ # "error": str(e)
72
+ # }
73
+
74
+
75
+
76
+ from fastapi import FastAPI, UploadFile, File
77
+ from transformers import AutoProcessor, GlmOcrForConditionalGeneration
78
+ from PIL import Image
79
+ import torch
80
+ import io
81
+
82
+ app = FastAPI()
83
+
84
+ MODEL_ID = "zai-org/GLM-OCR"
85
+
86
+ print("Loading GLM-OCR model...")
87
+
88
+ # Initialize Processor and Model specifically for GLM-OCR
89
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
90
+ model = GlmOcrForConditionalGeneration.from_pretrained(
91
  MODEL_ID,
92
+ trust_remote_code=True,
93
+ torch_dtype=torch.float32 # Use torch.bfloat16 if you have a GPU
94
+ ).eval()
95
 
96
  @app.get("/")
97
  async def root():
 
100
  @app.post("/ocr")
101
  async def extract_text(file: UploadFile = File(...)):
102
  try:
103
+ # Read and prepare image
104
  contents = await file.read()
105
  image = Image.open(io.BytesIO(contents)).convert("RGB")
106
 
107
+ # 1. Define the conversation structure
108
+ messages = [
109
+ {
110
+ "role": "user",
111
+ "content": [
112
+ {"type": "image"},
113
+ {"type": "text", "text": "Extract all text from this image."}
114
+ ],
115
+ }
116
+ ]
117
+
118
+ # 2. Use the chat template to prepare inputs
119
+ # This fixes the 'NoneType' error by providing valid input_ids
120
+ inputs = processor.apply_chat_template(
121
+ messages,
122
+ images=[image],
123
+ tokenize=True,
124
+ add_generation_prompt=True,
125
+ return_dict=True,
126
+ return_tensors="pt"
127
+ )
128
 
129
+ # 3. Generate
130
  with torch.no_grad():
131
+ outputs = model.generate(
132
+ **inputs,
133
+ max_new_tokens=1024,
134
+ do_sample=False
135
+ )
136
 
137
+ # 4. Decode the result
138
+ # We slice the output to remove the prompt tokens and keep only the response
139
+ generated_ids = outputs[:, inputs['input_ids'].shape[1]:]
140
+ text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
141
 
142
  return {
143
  "success": True,
144
+ "text": text.strip()
145
  }
146
 
147
  except Exception as e: