jarvisx17 committed on
Commit
d91bc32
·
verified ·
1 Parent(s): 49d1a8b

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +289 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,291 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import json
2
+ import re
 
3
  import streamlit as st
4
+ import pandas as pd
5
+ from typing import Any, List
6
+ from langchain_groq import ChatGroq
7
+ import os
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
+
12
+
13
# --- 1. Config ---
# Starter row used to seed the field editor before the user defines anything.
DEFAULT_FIELDS = [
    {
        "name": "number",
        "datatype": "int",
        "description": "Description of the item",
    },
]
# Datatype labels accepted from the editor, mapped to the type string that is
# written into the schema (currently an identity mapping).
TYPE_MAPPING_STR = {label: label for label in ("int", "float", "str")}
16
+
17
+
18
+ def normalize_fields(fields: Any) -> List[dict]:
19
+ """Convert DataFrame/list input into a clean list of field dicts."""
20
+ try:
21
+ if isinstance(fields, pd.DataFrame):
22
+ parsed = fields.fillna("").to_dict(orient="records")
23
+ elif isinstance(fields, list):
24
+ parsed = fields
25
+ else:
26
+ return []
27
+
28
+ cleaned = []
29
+ for item in parsed:
30
+ if not isinstance(item, dict):
31
+ continue
32
+ cleaned.append(
33
+ {
34
+ "name": str(item.get("name", "")).strip(),
35
+ "datatype": str(item.get("datatype", "str")).strip() or "str",
36
+ "description": str(item.get("description", "")).strip(),
37
+ }
38
+ )
39
+ return cleaned
40
+ except Exception:
41
+ return []
42
+
43
+
44
+ def generate_schema_json(fields: Any) -> str:
45
+ """Generate JSON schema-like object from field rows."""
46
+ normalized_fields = normalize_fields(fields)
47
+ properties = {}
48
+ required = []
49
+
50
+ for f in normalized_fields:
51
+ field_name = f.get("name", "").strip()
52
+ if not field_name:
53
+ continue
54
+ dtype = TYPE_MAPPING_STR.get(f.get("datatype", "str"), "str")
55
+ properties[field_name] = {
56
+ "type": dtype,
57
+ "description": f.get("description", ""),
58
+ "nullable": True,
59
+ }
60
+ required.append(field_name)
61
+
62
+ schema = {
63
+ "type": "object",
64
+ "properties": properties,
65
+ "required": required,
66
+ "additionalProperties": False,
67
+ }
68
+ return json.dumps(schema, indent=2)
69
+
70
+
71
+ def is_valid_text(text: str) -> bool:
72
+ """Guardrail: reject empty or whitespace-only input."""
73
+ return bool((text or "").strip())
74
+
75
+
76
+ def parse_json_from_text(text: str) -> dict | None:
77
+ """Extract JSON object from model response text."""
78
+ try:
79
+ # 1) direct JSON
80
+ parsed = json.loads(text)
81
+ return parsed if isinstance(parsed, dict) else None
82
+ except Exception:
83
+ pass
84
+
85
+ try:
86
+ # 2) fenced code block
87
+ fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL | re.IGNORECASE)
88
+ if fenced:
89
+ parsed = json.loads(fenced.group(1))
90
+ return parsed if isinstance(parsed, dict) else None
91
+ except Exception:
92
+ pass
93
+
94
+ try:
95
+ # 3) first object-looking block
96
+ obj = re.search(r"(\{.*\})", text, flags=re.DOTALL)
97
+ if obj:
98
+ parsed = json.loads(obj.group(1))
99
+ return parsed if isinstance(parsed, dict) else None
100
+ except Exception:
101
+ pass
102
+
103
+ return None
104
+
105
+
106
+ def cast_to_dtype(value: Any, dtype: str) -> Any:
107
+ if value is None:
108
+ return None
109
+ try:
110
+ if dtype == "int":
111
+ return int(value)
112
+ if dtype == "float":
113
+ return float(value)
114
+ return str(value)
115
+ except Exception:
116
+ return None
117
+
118
+
119
+ def extract_structured(fields: Any, unstructured_text: str) -> dict | str:
120
+ """
121
+ Extract structured data from unstructured text based on user-defined fields.
122
+
123
+ Args:
124
+ fields: A list of dicts or a pd.DataFrame with columns
125
+ [name, datatype, description].
126
+ unstructured_text: Raw text to extract data from.
127
+
128
+ Returns:
129
+ A JSON dict on success, or an error string.
130
+ """
131
+ if not is_valid_text(unstructured_text):
132
+ return "Input text is empty. Please provide some text to extract from."
133
+
134
+ # Build schema from user-defined fields
135
+ normalized_fields = normalize_fields(fields)
136
+ schema_properties = {}
137
+ field_order = []
138
+
139
+ for f in normalized_fields:
140
+ field_name = f.get("name", "").strip()
141
+ if not field_name:
142
+ continue
143
+ if not field_name.isidentifier():
144
+ return f"Invalid field name '{field_name}'. Use letters, numbers, and underscores only."
145
+ field_type = TYPE_MAPPING_STR.get(f.get("datatype", "str"), "str")
146
+ schema_properties[field_name] = {
147
+ "type": field_type,
148
+ "description": f.get("description", ""),
149
+ }
150
+ field_order.append(field_name)
151
+
152
+ if not schema_properties:
153
+ return "Please add at least one valid field before extraction."
154
+
155
+ # Initialize LLM
156
+ llm = ChatGroq(
157
+ model="openai/gpt-oss-120b",
158
+ temperature=0,
159
+ api_key=os.getenv("GROQ_API_KEY"),
160
+ )
161
+
162
+ # Extract structured data
163
+ try:
164
+ schema_json = json.dumps(schema_properties, indent=2)
165
+ response = llm.invoke(
166
+ "Extract information from the text below.\n"
167
+ "Return ONLY one valid JSON object and no extra text.\n"
168
+ "Use exactly the fields in this schema.\n"
169
+ "If a value is missing, return null.\n\n"
170
+ f"Schema:\n{schema_json}\n\n"
171
+ f"Text:\n{unstructured_text}"
172
+ )
173
+ content = response.content if hasattr(response, "content") else str(response)
174
+ if isinstance(content, list):
175
+ content = "".join(
176
+ part.get("text", "") if isinstance(part, dict) else str(part)
177
+ for part in content
178
+ )
179
+
180
+ parsed = parse_json_from_text(str(content))
181
+ if not parsed:
182
+ return f"Could not parse JSON from model output: {content}"
183
+
184
+ # Coerce output to requested schema and order
185
+ cleaned = {}
186
+ for field_name in field_order:
187
+ dtype = schema_properties[field_name]["type"]
188
+ cleaned[field_name] = cast_to_dtype(parsed.get(field_name), dtype)
189
+ return cleaned
190
+
191
+ except Exception as e:
192
+ return f"Error during extraction: {str(e)}"
193
+
194
+
195
+ def render_styles():
196
+ st.markdown(
197
+ """
198
+ <style>
199
+ .main-title {
200
+ font-size: 34px;
201
+ font-weight: 700;
202
+ margin-bottom: 4px;
203
+ }
204
+ .sub-title {
205
+ color: #6b7280;
206
+ margin-bottom: 20px;
207
+ }
208
+ .block-header {
209
+ font-size: 22px;
210
+ font-weight: 600;
211
+ margin: 8px 0 8px 0;
212
+ }
213
+ </style>
214
+ """,
215
+ unsafe_allow_html=True,
216
+ )
217
+
218
+
219
+ def main():
220
+ st.set_page_config(page_title="Dynamic Extraction", layout="wide")
221
+ render_styles()
222
+
223
+ st.markdown('<div class="main-title">Dynamic Invoice Extraction</div>', unsafe_allow_html=True)
224
+ st.markdown('<div class="sub-title">Json structured output</div>', unsafe_allow_html=True)
225
+
226
+ if "fields_df" not in st.session_state:
227
+ st.session_state.fields_df = pd.DataFrame(DEFAULT_FIELDS)
228
+ if "generated_schema" not in st.session_state:
229
+ st.session_state.generated_schema = ""
230
+ if "structured_result" not in st.session_state:
231
+ st.session_state.structured_result = ""
232
+ if "structured_result_json" not in st.session_state:
233
+ st.session_state.structured_result_json = {}
234
+
235
+ left_col, right_col = st.columns(2)
236
+
237
+ with left_col:
238
+ st.markdown('<div class="block-header">Define Entities / Fields</div>', unsafe_allow_html=True)
239
+ if st.button("+ Add Field", width="stretch"):
240
+ st.session_state.fields_df = pd.concat(
241
+ [st.session_state.fields_df, pd.DataFrame([{"name": "", "datatype": "str", "description": ""}])],
242
+ ignore_index=True,
243
+ )
244
+
245
+ edited_df = st.data_editor(
246
+ st.session_state.fields_df,
247
+ width="stretch",
248
+ num_rows="dynamic",
249
+ column_config={
250
+ "name": st.column_config.TextColumn("name"),
251
+ "datatype": st.column_config.SelectboxColumn("datatype", options=["str", "int", "float"]),
252
+ "description": st.column_config.TextColumn("description"),
253
+ },
254
+ key="fields_editor",
255
+ )
256
+ st.session_state.fields_df = edited_df
257
+
258
+ st.markdown('<div class="block-header">Paste Unstructured Text</div>', unsafe_allow_html=True)
259
+ unstructured_text = st.text_area(
260
+ "Example: https://huggingface.co/spaces/opendatalab/MinerU",
261
+ "Click on the above link and extract the mqarkdown text from that page and paste it here...",
262
+ placeholder="Paste your text here...",
263
+ height=220,
264
+ )
265
+ if st.button("Extract Structured Data", type="primary", width="stretch"):
266
+ with st.spinner("Extracting structured data..."):
267
+ result = extract_structured(st.session_state.fields_df, unstructured_text)
268
+ if isinstance(result, dict):
269
+ st.session_state.structured_result_json = result
270
+ st.session_state.structured_result = ""
271
+ else:
272
+ st.session_state.structured_result_json = {}
273
+ st.session_state.structured_result = result
274
+
275
+ with right_col:
276
+ st.markdown("### Structured Output (Transposed Table)")
277
+ if st.session_state.structured_result_json:
278
+ transposed_df = (
279
+ pd.DataFrame([st.session_state.structured_result_json])
280
+ .T.reset_index()
281
+ .rename(columns={"index": "Field", 0: "Value"})
282
+ )
283
+ st.dataframe(transposed_df, width="stretch", hide_index=True)
284
+ elif st.session_state.structured_result:
285
+ st.error(st.session_state.structured_result)
286
+ else:
287
+ st.info("Run extraction to see transposed table output.")
288
+
289
 
290
+ if __name__ == "__main__":
291
+ main()