haharvs committed on
Commit
df03a2f
·
verified ·
1 Parent(s): ed973c8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +185 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,187 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import requests
3
+ import json
4
+ import base64
5
 
6
+ # Configuration
7
+ API_URL = "https://openrouter.ai/api/v1/chat/completions"
8
+
9
def encode_image(uploaded_file):
    """Return the uploaded file's contents as a base64 text string.

    Args:
        uploaded_file: Any object exposing ``getvalue()`` that returns the
            raw bytes of the upload.

    Returns:
        The base64 encoding of those bytes, decoded to a ``str`` as UTF-8.
    """
    raw_bytes = uploaded_file.getvalue()
    encoded_bytes = base64.b64encode(raw_bytes)
    return str(encoded_bytes, 'utf-8')
13
+
14
def _guess_image_mime(base64_image):
    """Infer the image MIME type from the leading base64 characters."""
    # "iVBORw0KGgo" is the base64 encoding of the 8-byte PNG file signature.
    if isinstance(base64_image, str) and base64_image.startswith("iVBORw0KGgo"):
        return "image/png"
    # Anything else keeps the previous behaviour of being labelled as JPEG.
    return "image/jpeg"


def analyze_receipt(base64_image, prompt_text, api_key, *, mime_type=None, timeout=60):
    """Send the image and prompt to the OpenRouter chat-completions endpoint.

    Args:
        base64_image: Base64-encoded image data (no ``data:`` prefix).
        prompt_text: Instruction text sent alongside the image.
        api_key: OpenRouter API key used as the bearer token.
        mime_type: Optional MIME type for the data URL. When omitted it is
            inferred from the image signature, falling back to
            ``image/jpeg``.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot block the app indefinitely.

    Returns:
        The decoded JSON response on success, or ``{"error": <message>}`` when
        the key is missing, the request fails, or the body is not valid JSON.
    """
    if not api_key:
        return {"error": "API Key is missing."}

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    # Label the payload with its real type; PNG uploads were previously
    # always announced as JPEG.
    image_mime = mime_type or _guess_image_mime(base64_image)
    data_url = f"data:{image_mime};base64,{base64_image}"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt_text},
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }
    ]

    payload = {
        "model": "qwen/qwen3-vl-8b-instruct",
        "messages": messages,
    }

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise an error for bad status codes
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # decode error is not a RequestException subclass.
        return {"error": str(e)}
55
+
56
# Page chrome: browser-tab metadata first, then the visible heading and intro.
st.set_page_config(
    page_title="Receipt Analyzer",
    page_icon="🧾",
    layout="wide",
)
st.title("🧾 Receipt Cost Breakdown (Qwen 3-VL-8B)")
st.markdown("Upload a receipt image to get a JSON breakdown of costs.")

# Sidebar: collects the API key and every option that shapes the prompt.
with st.sidebar:
    st.header("⚙️ Configuration")

    # --- API access ---------------------------------------------------
    st.subheader("API Access")
    api_key = st.text_input(
        "OpenRouter API Key",
        type="password",
        help="Enter your OpenRouter API Key here.",
    )
    if not api_key:
        st.warning("Please enter your API Key to proceed.")

    st.divider()

    # --- Field picker -------------------------------------------------
    st.subheader("Fields to Extract")

    available_fields = [
        "Merchant Name",
        "Total Amount",
        "Currency",
        "Date",
        "Tax/VAT",
        "Address",
        "Time",
        "Payment Method",
    ]
    # The first four options are the ones pre-selected for the user.
    default_fields = available_fields[:4]

    selected_fields = st.multiselect(
        "Select fields:",
        options=available_fields,
        default=default_fields,
        help="Leave empty to extract **ALL** available information automatically.",
    )
    if not selected_fields:
        st.caption("✅ *No fields selected. The model will extract everything it finds.*")

    extract_line_items = st.checkbox("Extract Line Items (Name & Price)", value=True)

    st.divider()

    # --- Free-form extras ---------------------------------------------
    custom_instructions = st.text_input(
        "Custom Instructions (Optional)",
        placeholder="e.g., Extract the cashier name",
        help="Add any specific data points or rules not covered above.",
    )

    st.divider()

    # --- Model note, kept collapsed at the bottom of the sidebar --------
    with st.expander("ℹ️ About the Model", expanded=False):
        st.info(
            "**Qwen 3-VL-8B**\n\nThis is an open-source model efficient enough to run locally on consumer hardware."
        )
110
+
111
# Assemble the prompt from the sidebar choices: an opening instruction, then
# optional clauses, then the closing JSON-only requirement.
if selected_fields:
    # Specific fields were picked: list them explicitly.
    field_str = ", ".join(selected_fields)
    prompt_parts = [
        f"Analyze this receipt image. Extract the following information in JSON format: {field_str}."
    ]
else:
    # Nothing picked: ask the model for everything it can read.
    prompt_parts = [
        "Analyze this receipt image. Extract **all** visible information including merchant details, dates, totals, taxes, and address in a structured JSON format."
    ]

if extract_line_items:
    prompt_parts.append(" Also include a detailed list of 'items' containing 'name' and 'price'.")

if custom_instructions:
    prompt_parts.append(f" Additionally: {custom_instructions}.")

# Ask for bare JSON so the response can be parsed directly.
prompt_parts.append(" Return a single valid JSON object. Do not include markdown formatting.")

prompt_text = "".join(prompt_parts)

# Alias under the name the analysis call further down expects.
custom_prompt = prompt_text
131
+
132
# Main flow: upload a receipt, preview it on the left, show results on the right.
uploaded_file = st.file_uploader("Choose a receipt image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    col1, col2 = st.columns(2)

    with col1:
        # Display the uploaded image with the trigger button beneath it.
        # NOTE(review): newer Streamlit releases deprecate `use_column_width`;
        # confirm the pinned Streamlit version before switching arguments.
        st.image(uploaded_file, caption="Uploaded Receipt", use_column_width=True)
        analyze = st.button("Analyze Receipt", type="primary", use_container_width=True)

    with col2:
        if analyze:
            if not api_key:
                st.error("Please enter an API Key in the sidebar.")
            else:
                with st.spinner("Analyzing receipt..."):
                    # Encode image
                    base64_image = encode_image(uploaded_file)

                    # Call API
                    api_result = analyze_receipt(base64_image, custom_prompt, api_key)

                    # Handle response
                    if "error" in api_result:
                        st.error(f"Error calling API: {api_result['error']}")
                    elif "choices" in api_result:
                        # An empty `choices` list or a missing/null message must
                        # be reported as a bad response rather than crash the app.
                        try:
                            content = api_result["choices"][0]["message"]["content"]
                        except (KeyError, IndexError, TypeError):
                            content = None

                        if not isinstance(content, str):
                            st.error("Unexpected response format from API.")
                            st.json(api_result)
                        else:
                            with st.expander("🔍 Raw Analysis Output"):
                                st.code(content, language="json")

                            # Strip markdown code fences the model may have added
                            # despite the prompt, then parse what remains.
                            json_str = content.replace("```json", "").replace("```", "").strip()
                            try:
                                parsed_json = json.loads(json_str)
                            except json.JSONDecodeError:
                                st.warning("Could not parse the response as JSON. See the raw output above.")
                            else:
                                st.success("Analysis Complete!")
                                st.subheader("Structured Data")
                                st.json(parsed_json)

                                # The table and metrics need a JSON object; a
                                # top-level array or scalar has no `.items()`.
                                if isinstance(parsed_json, dict):
                                    # Optional: Display as a nice table
                                    line_items = parsed_json.get("items")
                                    if isinstance(line_items, list):
                                        st.subheader("Itemized Breakdown")
                                        st.dataframe(line_items, use_container_width=True)

                                    # Display other top-level keys as metrics if simple
                                    for key, value in parsed_json.items():
                                        if key != "items" and isinstance(value, (int, float, str)):
                                            st.metric(key.title(), value)
                    else:
                        st.error("Unexpected response format from API.")
                        st.json(api_result)