credent007 committed on
Commit
5b807f0
·
verified ·
1 Parent(s): f26ff7c

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +117 -117
inference.py CHANGED
@@ -1,89 +1,11 @@
1
- # import torch
2
- # from model_loader import model, processor, device
3
- # from processor_utils import load_input
4
- # from prompt import get_prompt
5
- # import json
6
- # def process_document(file_path):
7
- # image = load_input(file_path)
8
-
9
- # messages = [
10
- # {
11
- # "role": "user",
12
- # "content": [
13
- # {"type": "image", "image": image},
14
- # {"type": "text", "text": get_prompt()}
15
- # ]
16
- # }
17
- # ]
18
-
19
- # text = processor.apply_chat_template(
20
- # messages,
21
- # tokenize=False,
22
- # add_generation_prompt=True
23
- # )
24
-
25
- # inputs = processor(
26
- # text=[text],
27
- # images=[image],
28
- # return_tensors="pt"
29
- # ).to(device)
30
-
31
- # output = model.generate(
32
- # **inputs,
33
- # max_new_tokens=1500,
34
- # do_sample=False, # if it is true there will be extra text with output
35
- # # temperature=0.1 # temp is not required
36
- # )
37
-
38
- # generated_ids = output[0][inputs.input_ids.shape[-1]:]
39
-
40
- # # response = processor.decode( # past code
41
- # # generated_ids,
42
- # # skip_special_tokens=True
43
- # # )
44
-
45
- # # return response.strip()
46
-
47
-
48
-
49
- # response = processor.decode(
50
- # generated_ids,
51
- # skip_special_tokens=True
52
- # ).strip()
53
-
54
- # # 🔥 FORCE JSON CLEANING
55
- # start = response.find("{")
56
- # end = response.rfind("}") + 1
57
-
58
- # if start != -1 and end != -1:
59
- # response = response[start:end]
60
-
61
- # try:
62
- # parsed = json.loads(response)
63
- # except:
64
- # parsed = {
65
- # "error": "Invalid JSON",
66
- # "raw": response
67
- # }
68
-
69
- # return parsed
70
- import json
71
- from model_loader import get_model
72
  from processor_utils import load_input
73
  from prompt import get_prompt
 
 
 
74
 
75
-
76
- def _extract_json_block(text):
77
- start = text.find("{")
78
- end = text.rfind("}") + 1
79
-
80
- if start == -1 or end == 0:
81
- return None
82
-
83
- return text[start:end]
84
-
85
-
86
- def _run_page_inference(image, model, processor, device):
87
  messages = [
88
  {
89
  "role": "user",
@@ -108,53 +30,131 @@ def _run_page_inference(image, model, processor, device):
108
 
109
  output = model.generate(
110
  **inputs,
111
- max_new_tokens=150,
112
- do_sample=False
 
113
  )
114
 
115
  generated_ids = output[0][inputs.input_ids.shape[-1]:]
116
 
 
 
 
 
 
 
 
 
 
117
  response = processor.decode(
118
- generated_ids,
119
- skip_special_tokens=True
120
  ).strip()
121
 
122
- json_block = _extract_json_block(response)
 
 
123
 
124
- if not json_block:
125
- return {
126
- "status": "error",
127
- "raw_output": response,
128
- "parsed": None
129
- }
130
 
131
  try:
132
- parsed = json.loads(json_block)
133
- return {
134
- "status": "success",
135
- "raw_output": response,
136
- "parsed": parsed
137
- }
138
- except json.JSONDecodeError:
139
- return {
140
- "status": "error",
141
- "raw_output": response,
142
- "parsed": None
143
  }
144
 
 
 
 
 
 
 
145
 
146
- def process_document(file_path):
147
- model, processor, device = get_model()
148
- pages = load_input(file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
- page_results = []
151
 
152
- for page_number, image in enumerate(pages, start=1):
153
- result = _run_page_inference(image, model, processor, device)
154
- result["page_number"] = page_number
155
- page_results.append(result)
156
 
157
- return {
158
- "total_pages": len(page_results),
159
- "pages": page_results
160
- }
 
1
+ import torch
2
+ from model_loader import model, processor, device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from processor_utils import load_input
4
  from prompt import get_prompt
5
+ import json
6
+ def process_document(file_path):
7
+ image = load_input(file_path)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  messages = [
10
  {
11
  "role": "user",
 
30
 
31
  output = model.generate(
32
  **inputs,
33
+ max_new_tokens=1500,
34
+ do_sample=False, # if it is true there will be extra text with output
35
+ # temperature=0.1 # temp is not required
36
  )
37
 
38
  generated_ids = output[0][inputs.input_ids.shape[-1]:]
39
 
40
+ # response = processor.decode( # past code
41
+ # generated_ids,
42
+ # skip_special_tokens=True
43
+ # )
44
+
45
+ # return response.strip()
46
+
47
+
48
+
49
  response = processor.decode(
50
+ generated_ids,
51
+ skip_special_tokens=True
52
  ).strip()
53
 
54
+ # 🔥 FORCE JSON CLEANING
55
+ start = response.find("{")
56
+ end = response.rfind("}") + 1
57
 
58
+ if start != -1 and end != -1:
59
+ response = response[start:end]
 
 
 
 
60
 
61
  try:
62
+ parsed = json.loads(response)
63
+ except:
64
+ parsed = {
65
+ "error": "Invalid JSON",
66
+ "raw": response
 
 
 
 
 
 
67
  }
68
 
69
+ return parsed
70
+ # import json
71
+ # from model_loader import get_model
72
+ # from processor_utils import load_input
73
+ # from prompt import get_prompt
74
+
75
 
76
+ # def _extract_json_block(text):
77
+ # start = text.find("{")
78
+ # end = text.rfind("}") + 1
79
+
80
+ # if start == -1 or end == 0:
81
+ # return None
82
+
83
+ # return text[start:end]
84
+
85
+
86
+ # def _run_page_inference(image, model, processor, device):
87
+ # messages = [
88
+ # {
89
+ # "role": "user",
90
+ # "content": [
91
+ # {"type": "image", "image": image},
92
+ # {"type": "text", "text": get_prompt()}
93
+ # ]
94
+ # }
95
+ # ]
96
+
97
+ # text = processor.apply_chat_template(
98
+ # messages,
99
+ # tokenize=False,
100
+ # add_generation_prompt=True
101
+ # )
102
+
103
+ # inputs = processor(
104
+ # text=[text],
105
+ # images=[image],
106
+ # return_tensors="pt"
107
+ # ).to(device)
108
+
109
+ # output = model.generate(
110
+ # **inputs,
111
+ # max_new_tokens=150,
112
+ # do_sample=False
113
+ # )
114
+
115
+ # generated_ids = output[0][inputs.input_ids.shape[-1]:]
116
+
117
+ # response = processor.decode(
118
+ # generated_ids,
119
+ # skip_special_tokens=True
120
+ # ).strip()
121
+
122
+ # json_block = _extract_json_block(response)
123
+
124
+ # if not json_block:
125
+ # return {
126
+ # "status": "error",
127
+ # "raw_output": response,
128
+ # "parsed": None
129
+ # }
130
+
131
+ # try:
132
+ # parsed = json.loads(json_block)
133
+ # return {
134
+ # "status": "success",
135
+ # "raw_output": response,
136
+ # "parsed": parsed
137
+ # }
138
+ # except json.JSONDecodeError:
139
+ # return {
140
+ # "status": "error",
141
+ # "raw_output": response,
142
+ # "parsed": None
143
+ # }
144
+
145
+
146
+ # def process_document(file_path):
147
+ # model, processor, device = get_model()
148
+ # pages = load_input(file_path)
149
 
150
+ # page_results = []
151
 
152
+ # for page_number, image in enumerate(pages, start=1):
153
+ # result = _run_page_inference(image, model, processor, device)
154
+ # result["page_number"] = page_number
155
+ # page_results.append(result)
156
 
157
+ # return {
158
+ # "total_pages": len(page_results),
159
+ # "pages": page_results
160
+ # }