credent007 committed on
Commit
4482a85
·
verified ·
1 Parent(s): acf34d2

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -212
inference.py CHANGED
@@ -1,209 +1,3 @@
1
- # import torch
2
- # from model_loader import model, processor, device
3
- # from processor_utils import load_input
4
- # from prompt import get_prompt
5
- # import json
6
- # def process_document(image):
7
- # # images = load_input(file_path)
8
- # # image = images[0]
9
- # # print("Checking input type and no of pages in pdf")
10
- # # print(type(image))
11
- # # print(type(images))
12
- # # print(len(images))
13
-
14
-
15
-
16
- # messages = [
17
- # {
18
- # "role": "user",
19
- # "content": [
20
- # {"type": "image", "image": image},
21
- # {"type": "text", "text": get_prompt()}
22
- # ]
23
- # }
24
- # ]
25
-
26
- # text = processor.apply_chat_template(
27
- # messages,
28
- # tokenize=False, # so that this can return string output
29
- # add_generation_prompt=True # if true it will add extra on start and end
30
- # )
31
- # # print(f"The text of inference is {text}")
32
-
33
- # inputs = processor(
34
- # text=[text],
35
- # images=[image],
36
- # return_tensors="pt"
37
- # ).to(device)
38
- # # print(f"The inputs of inference is {inputs}")
39
-
40
- # output = model.generate(
41
- # **inputs,
42
- # max_new_tokens=1500,
43
- # do_sample=False, # if it is true there will be extra text with output
44
- # # temperature=0.1 # temp is not required
45
- # )
46
- # # print(f"The output of inference is {output}")
47
-
48
-
49
- # generated_ids = output[0][inputs.input_ids.shape[-1]:]
50
- # # print(f"The generated_ids of inference is {generated_ids}")
51
-
52
- # # response = processor.decode( # past code
53
- # # generated_ids,
54
- # # skip_special_tokens=True
55
- # # )
56
-
57
- # # return response.strip()
58
-
59
-
60
-
61
- # response = processor.decode(
62
- # generated_ids,
63
- # skip_special_tokens=True
64
- # ).strip()
65
- # # print(f"The response of inference is {response}")
66
-
67
- # # 🔥 FORCE JSON CLEANING
68
- # start = response.find("{")
69
- # end = response.rfind("}") + 1
70
-
71
- # if start != -1 and end != -1:
72
- # response = response[start:end]
73
-
74
- # print(f"The type of response is before{response}")
75
- # try:
76
- # parsed = json.loads(response)
77
- # except:
78
- # parsed = {
79
- # "error":[
80
- # response
81
- # ]
82
- # # "Invalid JSON",
83
- # # "raw": response
84
- # }
85
- # print(f"The type of response is after{response}")
86
-
87
- # return response
88
-
89
-
90
- # import json
91
- # from model_loader import get_model
92
- # from processor_utils import load_input
93
- # from prompt import get_part_classifier_prompt, get_part_prompt
94
-
95
-
96
- # def _run_model(image, prompt_text, model, processor, device):
97
- # messages = [
98
- # {
99
- # "role": "user",
100
- # "content": [
101
- # {"type": "image", "image": image},
102
- # {"type": "text", "text": prompt_text}
103
- # ]
104
- # }
105
- # ]
106
-
107
- # text = processor.apply_chat_template(
108
- # messages,
109
- # tokenize=False,
110
- # add_generation_prompt=True
111
- # )
112
-
113
- # inputs = processor(
114
- # text=[text],
115
- # images=[image],
116
- # return_tensors="pt"
117
- # ).to(device)
118
-
119
- # output = model.generate(
120
- # **inputs,
121
- # max_new_tokens=400,
122
- # do_sample=False
123
- # )
124
-
125
- # generated_ids = output[0][inputs.input_ids.shape[-1]:]
126
- # response = processor.decode(generated_ids, skip_special_tokens=True).strip()
127
- # return response
128
-
129
-
130
- # def _extract_json_block(text):
131
- # start = text.find("{")
132
- # end = text.rfind("}") + 1
133
- # if start == -1 or end == 0:
134
- # return None
135
- # return text[start:end]
136
-
137
-
138
- # def classify_page(image, model, processor, device):
139
- # raw = _run_model(image, get_part_classifier_prompt(), model, processor, device)
140
- # raw = raw.strip().upper()
141
-
142
- # valid_parts = {"PART-1", "PART-2", "PART-3", "PART-4", "PART-5", "PART-6"}
143
- # for part in valid_parts:
144
- # if part in raw:
145
- # return part
146
-
147
- # return "UNKNOWN"
148
-
149
-
150
- # def extract_part_json(image, part_name, model, processor, device):
151
- # raw = _run_model(image, get_part_prompt(part_name), model, processor, device)
152
- # json_block = _extract_json_block(raw)
153
-
154
- # if not json_block:
155
- # return {
156
- # "status": "error",
157
- # "part": part_name,
158
- # "raw_output": raw,
159
- # "parsed": None
160
- # }
161
-
162
- # try:
163
- # parsed = json.loads(json_block)
164
- # return {
165
- # "status": "success",
166
- # "part": part_name,
167
- # "raw_output": raw,
168
- # "parsed": parsed
169
- # }
170
- # except json.JSONDecodeError:
171
- # return {
172
- # "status": "error",
173
- # "part": part_name,
174
- # "raw_output": raw,
175
- # "parsed": None
176
- # }
177
-
178
-
179
- # def process_document(file_path):
180
- # model, processor, device = get_model()
181
- # pages = load_input(file_path)
182
-
183
- # page_results = []
184
-
185
- # for idx, image in enumerate(pages, start=1):
186
- # part_name = classify_page(image, model, processor, device)
187
-
188
- # if part_name == "UNKNOWN":
189
- # page_results.append({
190
- # "page_number": idx,
191
- # "status": "error",
192
- # "part": "UNKNOWN",
193
- # "raw_output": "",
194
- # "parsed": None
195
- # })
196
- # continue
197
-
198
- # result = extract_part_json(image, part_name, model, processor, device)
199
- # result["page_number"] = idx
200
- # page_results.append(result)
201
-
202
- # return {
203
- # "total_pages": len(page_results),
204
- # "pages": page_results
205
- # }
206
-
207
  import json
208
  from model_loader import get_model
209
  from processor_utils import load_input
@@ -458,10 +252,4 @@ def process_document(file_path):
458
  page_results.append(result)
459
 
460
  final_json = merge_page_results(page_results)
461
-
462
- # return {
463
- # "final_json": final_json
464
- # # "total_pages": len(page_results),
465
- # # "pages": page_results
466
- # }
467
  return final_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  from model_loader import get_model
3
  from processor_utils import load_input
 
252
  page_results.append(result)
253
 
254
  final_json = merge_page_results(page_results)
 
 
 
 
 
 
255
  return final_json