credent007 committed on
Commit
33c13a8
·
verified ·
1 Parent(s): 5de77ae

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +121 -30
inference.py CHANGED
@@ -1,11 +1,89 @@
1
- import torch
2
- from model_loader import model, processor, device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from processor_utils import load_input
4
  from prompt import get_prompt
5
- import json
6
- def process_document(file_path):
7
- image = load_input(file_path)
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  messages = [
10
  {
11
  "role": "user",
@@ -30,40 +108,53 @@ def process_document(file_path):
30
 
31
  output = model.generate(
32
  **inputs,
33
- max_new_tokens=1500,
34
- do_sample=False, # if it is true there will be extra text with output
35
- # temperature=0.1 # temp is not required
36
  )
37
 
38
  generated_ids = output[0][inputs.input_ids.shape[-1]:]
39
 
40
- # response = processor.decode( # past code
41
- # generated_ids,
42
- # skip_special_tokens=True
43
- # )
44
-
45
- # return response.strip()
46
-
47
-
48
-
49
  response = processor.decode(
50
- generated_ids,
51
- skip_special_tokens=True
52
  ).strip()
53
 
54
- # 🔥 FORCE JSON CLEANING
55
- start = response.find("{")
56
- end = response.rfind("}") + 1
57
 
58
- if start != -1 and end != -1:
59
- response = response[start:end]
 
 
 
 
60
 
61
  try:
62
- parsed = json.loads(response)
63
- except:
64
- parsed = {
65
- "error": "Invalid JSON",
66
- "raw": response
67
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- return parsed
 
 
 
 
1
+ # import torch
2
+ # from model_loader import model, processor, device
3
+ # from processor_utils import load_input
4
+ # from prompt import get_prompt
5
+ # import json
6
+ # def process_document(file_path):
7
+ # image = load_input(file_path)
8
+
9
+ # messages = [
10
+ # {
11
+ # "role": "user",
12
+ # "content": [
13
+ # {"type": "image", "image": image},
14
+ # {"type": "text", "text": get_prompt()}
15
+ # ]
16
+ # }
17
+ # ]
18
+
19
+ # text = processor.apply_chat_template(
20
+ # messages,
21
+ # tokenize=False,
22
+ # add_generation_prompt=True
23
+ # )
24
+
25
+ # inputs = processor(
26
+ # text=[text],
27
+ # images=[image],
28
+ # return_tensors="pt"
29
+ # ).to(device)
30
+
31
+ # output = model.generate(
32
+ # **inputs,
33
+ # max_new_tokens=1500,
34
+ # do_sample=False, # if it is true there will be extra text with output
35
+ # # temperature=0.1 # temp is not required
36
+ # )
37
+
38
+ # generated_ids = output[0][inputs.input_ids.shape[-1]:]
39
+
40
+ # # response = processor.decode( # past code
41
+ # # generated_ids,
42
+ # # skip_special_tokens=True
43
+ # # )
44
+
45
+ # # return response.strip()
46
+
47
+
48
+
49
+ # response = processor.decode(
50
+ # generated_ids,
51
+ # skip_special_tokens=True
52
+ # ).strip()
53
+
54
+ # # 🔥 FORCE JSON CLEANING
55
+ # start = response.find("{")
56
+ # end = response.rfind("}") + 1
57
+
58
+ # if start != -1 and end != -1:
59
+ # response = response[start:end]
60
+
61
+ # try:
62
+ # parsed = json.loads(response)
63
+ # except:
64
+ # parsed = {
65
+ # "error": "Invalid JSON",
66
+ # "raw": response
67
+ # }
68
+
69
+ # return parsed
70
+ import json
71
+ from model_loader import get_model
72
  from processor_utils import load_input
73
  from prompt import get_prompt
 
 
 
74
 
75
+
76
+ def _extract_json_block(text):
77
+ start = text.find("{")
78
+ end = text.rfind("}") + 1
79
+
80
+ if start == -1 or end == 0:
81
+ return None
82
+
83
+ return text[start:end]
84
+
85
+
86
+ def _run_page_inference(image, model, processor, device):
87
  messages = [
88
  {
89
  "role": "user",
 
108
 
109
  output = model.generate(
110
  **inputs,
111
+ max_new_tokens=150,
112
+ do_sample=False
 
113
  )
114
 
115
  generated_ids = output[0][inputs.input_ids.shape[-1]:]
116
 
 
 
 
 
 
 
 
 
 
117
  response = processor.decode(
118
+ generated_ids,
119
+ skip_special_tokens=True
120
  ).strip()
121
 
122
+ json_block = _extract_json_block(response)
 
 
123
 
124
+ if not json_block:
125
+ return {
126
+ "status": "error",
127
+ "raw_output": response,
128
+ "parsed": None
129
+ }
130
 
131
  try:
132
+ parsed = json.loads(json_block)
133
+ return {
134
+ "status": "success",
135
+ "raw_output": response,
136
+ "parsed": parsed
137
  }
138
+ except json.JSONDecodeError:
139
+ return {
140
+ "status": "error",
141
+ "raw_output": response,
142
+ "parsed": None
143
+ }
144
+
145
+
146
def process_document(file_path):
    """Run per-page inference over a document and collect the results.

    Obtains ``(model, processor, device)`` from ``get_model()``, loads the
    input with ``load_input(file_path)``, and calls ``_run_page_inference``
    once per item, in order. Each per-page result dict is tagged in place
    with a 1-based ``page_number``.

    Args:
        file_path: Passed straight through to ``load_input``.
            NOTE(review): presumably ``load_input`` returns an iterable of
            page images -- confirm against ``processor_utils``.

    Returns:
        A dict with ``total_pages`` (number of collected results) and
        ``pages`` (the list of per-page result dicts).
    """
    model, processor, device = get_model()
    pages = load_input(file_path)

    collected = []
    for index, page_image in enumerate(pages):
        page_result = _run_page_inference(page_image, model, processor, device)
        # Page numbers are reported 1-based.
        page_result.update(page_number=index + 1)
        collected.append(page_result)

    return {"total_pages": len(collected), "pages": collected}