credent007 commited on
Commit
3a59309
·
verified ·
1 Parent(s): e7f2cb2

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +29 -5
inference.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  from model_loader import model, processor, device
3
  from processor_utils import load_input
4
  from prompt import get_prompt
5
-
6
  def process_document(file_path):
7
  image = load_input(file_path)
8
 
@@ -37,9 +37,33 @@ def process_document(file_path):
37
 
38
  generated_ids = output[0][inputs.input_ids.shape[-1]:]
39
 
 
 
 
 
 
 
 
 
 
40
  response = processor.decode(
41
- generated_ids,
42
- skip_special_tokens=True
43
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- return response.strip()
 
2
  from model_loader import model, processor, device
3
  from processor_utils import load_input
4
  from prompt import get_prompt
5
+ import json
6
  def process_document(file_path):
7
  image = load_input(file_path)
8
 
 
37
 
38
  generated_ids = output[0][inputs.input_ids.shape[-1]:]
39
 
40
+ # response = processor.decode( # past code
41
+ # generated_ids,
42
+ # skip_special_tokens=True
43
+ # )
44
+
45
+ # return response.strip()
46
+
47
+
48
+
49
  response = processor.decode(
50
+ generated_ids,
51
+ skip_special_tokens=True
52
+ ).strip()
53
+
54
+ # 🔥 FORCE JSON CLEANING
55
+ start = response.find("{")
56
+ end = response.rfind("}") + 1
57
+
58
+ if start != -1 and end != -1:
59
+ response = response[start:end]
60
+
61
+ try:
62
+ parsed = json.loads(response)
63
+ except:
64
+ parsed = {
65
+ "error": "Invalid JSON",
66
+ "raw": response
67
+ }
68
 
69
+ return parsed