Seth0330 committed on
Commit
7e5da41
·
verified ·
1 Parent(s): 0ce55d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -99
app.py CHANGED
@@ -20,13 +20,13 @@ MODELS = {
20
  "api_url": "https://api.deepseek.com/v1/chat/completions",
21
  "model_name": "deepseek-chat",
22
  "api_key_env": "DEEPSEEK_API_KEY",
23
- "response_format": {"type": "json_object"}
24
  },
25
  "DeepSeek R1": {
26
  "api_url": "https://api.deepseek.com/v1/chat/completions",
27
  "model_name": "deepseek-reasoner",
28
  "api_key_env": "DEEPSEEK_API_KEY",
29
- "response_format": None
30
  },
31
  "Llama 4 Mavericks": {
32
  "api_url": "https://openrouter.ai/api/v1/chat/completions",
@@ -35,8 +35,8 @@ MODELS = {
35
  "response_format": {"type": "json_object"},
36
  "extra_headers": {
37
  "HTTP-Referer": "https://huggingface.co",
38
- "X-Title": "Invoice Extractor"
39
- }
40
  },
41
  "Mistral Small": {
42
  "api_url": "https://openrouter.ai/api/v1/chat/completions",
@@ -45,160 +45,197 @@ MODELS = {
45
  "response_format": {"type": "json_object"},
46
  "extra_headers": {
47
  "HTTP-Referer": "https://huggingface.co",
48
- "X-Title": "Invoice Extractor"
49
- }
50
- }
51
  }
52
 
53
  def get_api_key(model_choice):
54
- api_key = os.environ.get(MODELS[model_choice]["api_key_env"])
55
- if not api_key:
56
- st.error(f"❌ {MODELS[model_choice]['api_key_env']} environment variable not set!")
57
  st.stop()
58
- return api_key
59
 
60
  def query_llm(model_choice, prompt):
61
- config = MODELS[model_choice]
62
  headers = {
63
  "Authorization": f"Bearer {get_api_key(model_choice)}",
64
  "Content-Type": "application/json",
65
  }
66
- if config.get("extra_headers"):
67
- headers.update(config["extra_headers"])
68
 
69
  payload = {
70
- "model": config["model_name"],
71
  "messages": [{"role": "user", "content": prompt}],
72
  "temperature": 0.1,
73
  "max_tokens": 2000,
74
  }
75
- if config.get("response_format"):
76
- payload["response_format"] = config["response_format"]
77
 
78
  try:
79
- with st.spinner(f"🔍 Analyzing with {model_choice}..."):
80
- resp = requests.post(config["api_url"], headers=headers, json=payload, timeout=90)
81
- if resp.status_code != 200:
82
- st.error(f"🚨 API Error {resp.status_code}: {resp.text}")
83
- return None
84
- content = resp.json()["choices"][0]["message"]["content"]
85
- st.session_state.last_api_response = content
86
- st.session_state.last_api_response_raw = resp.text
87
- return content
88
- except requests.exceptions.RequestException as e:
89
- st.error(f"🌐 Connection Failed: {e}")
90
  return None
91
 
92
  def clean_json_response(text):
93
- """Strip code fences and extract a valid JSON segment."""
94
  if not text:
95
  return None
96
  original = text
97
- # Remove any ``` or ```json fences
98
- text = re.sub(r'```(?:json)?', '', text)
99
- text = text.strip()
100
-
101
- # Find the JSON object boundaries
102
  start = text.find('{')
103
  end = text.rfind('}') + 1
104
- if start == -1 or end == 0:
105
- st.error("Failed to locate JSON in the response.")
106
  st.code(original)
107
  return None
108
- json_str = text[start:end]
109
-
110
  try:
111
- return json.loads(json_str)
112
  except json.JSONDecodeError as e:
113
- st.error(f"JSON decode error: {e}")
114
- st.code(json_str)
115
  return None
116
 
117
  def get_extraction_prompt(model_choice, text):
118
- # (Prompts abbreviated here for readability—use your existing prompt definitions)
119
  if model_choice == "DeepSeek v3":
120
- return "..." # your DeepSeek v3 prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  elif model_choice == "DeepSeek R1":
122
- return "..." # your DeepSeek R1 prompt
123
- else:
124
- return "..." # generic Llama/Mistral prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  def extract_invoice_info(model_choice, text):
127
  prompt = get_extraction_prompt(model_choice, text)
128
- result = query_llm(model_choice, prompt)
129
- if not result:
130
  return None
131
- data = clean_json_response(result)
 
 
 
 
 
132
  if not data:
133
  return None
134
 
135
- # Normalize structure
136
  if model_choice in ["Llama 4 Mavericks", "Mistral Small"]:
137
- header = data.setdefault("invoice_header", {})
138
- for key in ["invoice_number", "invoice_date", "po_number", "invoice_value", "supplier_name", "customer_name"]:
139
- header.setdefault(key, None)
140
  items = data.setdefault("line_items", [])
141
- for item in items:
142
- for key in ["item_number", "description", "quantity", "unit_price", "total_price"]:
143
- item.setdefault(key, None)
144
  else:
145
- for key in ["invoice_number", "invoice_date", "po_number", "invoice_value"]:
146
- data.setdefault(key, None)
147
  items = data.setdefault("line_items", [])
148
- for item in items:
149
- for key in ["description", "quantity", "unit_price", "total_price"]:
150
- item.setdefault(key, None)
151
 
152
  return data
153
 
154
- # ---- UI Layout ----
155
  tab1, tab2 = st.tabs(["PDF Summarizer", "Invoice Extractor"])
156
 
157
  with tab1:
158
- st.title("PDF to Bullet Point Summarizer 🗟")
159
- pdf_file = st.file_uploader("Upload PDF", type="pdf")
160
- scale = st.slider("Summarization extent (%)", 1, 100, 20)
161
- if st.button("Generate Summary") and pdf_file:
162
- text = read_pdf(io.BytesIO(pdf_file.getvalue()))
163
- phrases = extract_key_phrases(text)
164
- scores = score_sentences(text, phrases)
165
- count = max(1, len(scores) * scale // 100)
166
- summary = summarize_text(scores, num_points=count)
167
- st.subheader("Summary:")
168
- st.markdown(summary)
169
 
170
  with tab2:
171
- st.title("📋 Invoice Extractor from PDF")
172
- model_choice = st.selectbox("Select AI Model", list(MODELS.keys()))
173
- invoice_pdf = st.file_uploader("Upload Invoice PDF", type="pdf")
174
- if st.button("Extract Invoice Information") and invoice_pdf:
175
- invoice_text = read_pdf(io.BytesIO(invoice_pdf.getvalue()))
176
- invoice_data = extract_invoice_info(model_choice, invoice_text)
177
- if invoice_data:
178
- st.success("Extraction Complete!")
179
- if model_choice in ["Llama 4 Mavericks", "Mistral Small"]:
180
- hdr = invoice_data["invoice_header"]
181
  c1, c2, c3 = st.columns(3)
182
- c1.metric("Invoice #", hdr.get("invoice_number"))
183
- c1.metric("Supplier", hdr.get("supplier_name"))
184
- c2.metric("Date", hdr.get("invoice_date"))
185
- c2.metric("Customer", hdr.get("customer_name"))
186
- c3.metric("PO #", hdr.get("po_number"))
187
- c3.metric("Total", hdr.get("invoice_value"))
188
  st.subheader("Line Items")
189
- st.table(invoice_data["line_items"])
190
  else:
191
  c1, c2 = st.columns(2)
192
- c1.metric("Invoice #", invoice_data.get("invoice_number"))
193
- c1.metric("PO #", invoice_data.get("po_number"))
194
- c2.metric("Date", invoice_data.get("invoice_date"))
195
- c2.metric("Value", invoice_data.get("invoice_value"))
196
  st.subheader("Line Items")
197
- st.table(invoice_data["line_items"])
198
 
199
  if "last_api_response" in st.session_state:
200
- with st.expander("Debug Information"):
201
- st.write("Extracted content (raw string):")
202
  st.code(st.session_state.last_api_response)
203
- st.write("Full HTTP response text:")
204
- st.code(st.session_state.get("last_api_response_raw", "No response"))
 
20
  "api_url": "https://api.deepseek.com/v1/chat/completions",
21
  "model_name": "deepseek-chat",
22
  "api_key_env": "DEEPSEEK_API_KEY",
23
+ "response_format": {"type": "json_object"},
24
  },
25
  "DeepSeek R1": {
26
  "api_url": "https://api.deepseek.com/v1/chat/completions",
27
  "model_name": "deepseek-reasoner",
28
  "api_key_env": "DEEPSEEK_API_KEY",
29
+ "response_format": None,
30
  },
31
  "Llama 4 Mavericks": {
32
  "api_url": "https://openrouter.ai/api/v1/chat/completions",
 
35
  "response_format": {"type": "json_object"},
36
  "extra_headers": {
37
  "HTTP-Referer": "https://huggingface.co",
38
+ "X-Title": "Invoice Extractor",
39
+ },
40
  },
41
  "Mistral Small": {
42
  "api_url": "https://openrouter.ai/api/v1/chat/completions",
 
45
  "response_format": {"type": "json_object"},
46
  "extra_headers": {
47
  "HTTP-Referer": "https://huggingface.co",
48
+ "X-Title": "Invoice Extractor",
49
+ },
50
+ },
51
  }
52
 
53
def get_api_key(model_choice):
    """Return the API key for *model_choice* from the environment.

    Reads the environment variable named by the model's ``api_key_env``
    config entry. If the variable is unset, shows a Streamlit error and
    halts the script run via ``st.stop()``.
    """
    env_name = MODELS[model_choice]["api_key_env"]
    api_key = os.environ.get(env_name)
    if api_key:
        return api_key
    st.error(f"❌ {env_name} not set")
    st.stop()
59
 
60
def query_llm(model_choice, prompt):
    """Send *prompt* to the chat-completions endpoint for *model_choice*.

    Builds headers/payload from the ``MODELS`` config, POSTs the request,
    and returns the assistant message content string, or ``None`` on any
    failure (after surfacing a Streamlit error).

    Side effects: stores the extracted content in
    ``st.session_state.last_api_response`` and the full HTTP body in
    ``st.session_state.last_api_raw`` on success.
    """
    cfg = MODELS[model_choice]
    headers = {
        "Authorization": f"Bearer {get_api_key(model_choice)}",
        "Content-Type": "application/json",
    }
    # OpenRouter models carry attribution headers; others define none.
    if cfg.get("extra_headers"):
        headers.update(cfg["extra_headers"])

    payload = {
        "model": cfg["model_name"],
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.1,  # near-deterministic output for extraction
        "max_tokens": 2000,
    }
    # Some models (e.g. deepseek-reasoner) take no response_format.
    if cfg.get("response_format"):
        payload["response_format"] = cfg["response_format"]

    try:
        with st.spinner(f"🔍 Querying {model_choice}..."):
            resp = requests.post(cfg["api_url"], headers=headers, json=payload, timeout=90)
        if resp.status_code != 200:
            st.error(f"🚨 API Error {resp.status_code}: {resp.text}")
            return None
        content = resp.json()["choices"][0]["message"]["content"]
        st.session_state.last_api_response = content
        st.session_state.last_api_raw = resp.text
        return content
    except requests.exceptions.RequestException as e:
        # Network-level failures only: DNS, connect, timeout, etc.
        # (A bare `except Exception` here would also swallow KeyError /
        # ValueError from a malformed body and mislabel them.)
        st.error(f"Connection failed: {e}")
        return None
    except (KeyError, IndexError, ValueError) as e:
        # Response arrived but its JSON body did not have the expected
        # choices[0].message.content shape.
        st.error(f"🚨 Unexpected API response format: {e}")
        return None
91
 
92
def clean_json_response(text):
    """Strip markdown code fences from *text* and parse the outermost
    JSON object; return the parsed value or ``None`` on failure.

    Failure paths (no braces found / invalid JSON) echo the offending
    text back to the UI via ``st.error`` / ``st.code``.
    """
    if not text:
        return None
    original = text
    # Remove ``` / ```json fence markers wherever they occur.
    stripped = re.sub(r'```(?:json)?', '', text).strip()
    open_idx = stripped.find('{')
    close_idx = stripped.rfind('}') + 1
    if open_idx < 0 or close_idx < 1:
        st.error("Couldn't locate JSON in response.")
        st.code(original)
        return None
    candidate = stripped[open_idx:close_idx]
    try:
        return json.loads(candidate)
    except json.JSONDecodeError as e:
        st.error(f"JSON parse error: {e}")
        st.code(candidate)
        return None
112
 
113
def get_extraction_prompt(model_choice, text):
    """Build the model-specific invoice-extraction prompt for *text*.

    NOTE: every template deliberately contains the lowercase word
    "json" — providers that enforce JSON output mode require it to
    appear in the prompt.
    """
    deepseek_v3_template = (
        "Extract complete invoice information and return ONLY a valid json object with these fields:\n"
        "{\n"
        ' "invoice_number": "string",\n'
        ' "invoice_date": "YYYY-MM-DD",\n'
        ' "po_number": "string or null",\n'
        ' "invoice_value": "string with currency symbol",\n'
        ' "line_items": [\n'
        " {...}\n"
        " ]\n"
        "}\n"
        "Rules:\n"
        "1. Use null for missing fields\n"
        "2. Do not include any additional text\n\n"
        "Invoice Text:\n"
    )
    deepseek_r1_template = (
        "Please extract invoice info from the text below and return only raw json:\n"
        "{...}\n"
        "Invoice Text:\n"
    )
    # Llama / Mistral share one template with a header/items split.
    generic_template = (
        "Extract complete invoice information and return a valid json object with these fields:\n"
        "{\n"
        ' "invoice_header": {...},\n'
        ' "line_items": [...]\n'
        "}\n"
        "Rules:\n"
        "1. Return ONLY json\n"
        "2. Date format YYYY-MM-DD\n"
        "3. Currency values with symbol\n"
        "4. Do not include any explanations\n\n"
        "Invoice Text:\n"
    )

    if model_choice == "DeepSeek v3":
        template = deepseek_v3_template
    elif model_choice == "DeepSeek R1":
        template = deepseek_r1_template
    else:
        template = generic_template
    return template + text
157
 
158
def extract_invoice_info(model_choice, text):
    """Full pipeline: build prompt -> query LLM -> parse JSON -> normalize.

    Returns the normalized dict, or ``None`` when any stage fails.
    Normalization guarantees every key the UI indexes into exists
    (filled with ``None``), so downstream ``dict[...]`` access is safe.
    """
    reply = query_llm(model_choice, get_extraction_prompt(model_choice, text))
    if reply is None:
        return None
    if not reply.strip():
        st.error("Empty response from API.")
        st.code(st.session_state.last_api_raw)
        return None

    data = clean_json_response(reply)
    if not data:
        return None

    # OpenRouter models return a nested header; DeepSeek returns flat keys.
    if model_choice in ["Llama 4 Mavericks", "Mistral Small"]:
        header = data.setdefault("invoice_header", {})
        for field in ["invoice_number", "invoice_date", "po_number",
                      "invoice_value", "supplier_name", "customer_name"]:
            header.setdefault(field, None)
        item_fields = ["item_number", "description", "quantity",
                       "unit_price", "total_price"]
    else:
        for field in ["invoice_number", "invoice_date", "po_number", "invoice_value"]:
            data.setdefault(field, None)
        item_fields = ["description", "quantity", "unit_price", "total_price"]

    for entry in data.setdefault("line_items", []):
        for field in item_fields:
            entry.setdefault(field, None)

    return data
190
 
191
# ---- UI ----
# Top-level Streamlit script body: re-runs on every widget interaction.
tab1, tab2 = st.tabs(["PDF Summarizer", "Invoice Extractor"])

with tab1:
    # Extractive summarizer: key phrases -> sentence scores -> top-N bullets.
    st.title("PDF to Bullet Point Summarizer")
    pdf = st.file_uploader("Upload PDF", type="pdf")
    pct = st.slider("Summarization (%)", 1, 100, 20)
    if st.button("Summarize") and pdf:
        txt = read_pdf(io.BytesIO(pdf.getvalue()))
        keys = extract_key_phrases(txt)
        scores = score_sentences(txt, keys)
        # Keep at least one sentence even for tiny documents / low pct.
        n = max(1, len(scores) * pct // 100)
        bullet = summarize_text(scores, num_points=n)
        st.subheader("Summary")
        st.markdown(bullet)

with tab2:
    # LLM-backed invoice extraction: model choice drives prompt + schema.
    st.title("Invoice Extractor")
    mdl = st.selectbox("Model", list(MODELS.keys()))
    inv_pdf = st.file_uploader("Invoice PDF", type="pdf")
    if st.button("Extract") and inv_pdf:
        txt = read_pdf(io.BytesIO(inv_pdf.getvalue()))
        info = extract_invoice_info(mdl, txt)
        if info:
            st.success("Done")
            if mdl in ["Llama 4 Mavericks", "Mistral Small"]:
                # OpenRouter models: nested header dict with supplier/customer.
                # Keys are safe to index: extract_invoice_info backfills them.
                h = info["invoice_header"]
                c1, c2, c3 = st.columns(3)
                c1.metric("Invoice #", h["invoice_number"])
                c1.metric("Supplier", h["supplier_name"])
                c2.metric("Date", h["invoice_date"])
                c2.metric("Customer", h["customer_name"])
                c3.metric("PO #", h["po_number"])
                c3.metric("Total", h["invoice_value"])
                st.subheader("Line Items")
                st.table(info["line_items"])
            else:
                # DeepSeek models: flat top-level keys, no supplier/customer.
                c1, c2 = st.columns(2)
                c1.metric("Invoice #", info["invoice_number"])
                c1.metric("PO #", info["po_number"])
                c2.metric("Date", info["invoice_date"])
                c2.metric("Value", info["invoice_value"])
                st.subheader("Line Items")
                st.table(info["line_items"])

    # Debug panel: shown whenever a previous run stored an API response.
    # NOTE(review): the diff rendering makes the nesting ambiguous — this
    # block may belong inside the button handler instead; confirm against
    # the full file.
    if "last_api_response" in st.session_state:
        with st.expander("Debug"):
            st.write("Raw assistant content:")
            st.code(st.session_state.last_api_response)
            st.write("Full HTTP response:")
            st.code(st.session_state.last_api_raw)