Toulik committed on
Commit
b83f80e
·
verified ·
1 Parent(s): 9b904c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -63
app.py CHANGED
@@ -1,70 +1,223 @@
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,
):
    """
    Stream a chat reply, yielding the accumulated response text after every chunk.

    The conversation sent to the model is the system message, followed by the
    prior `history`, followed by the new user `message`.

    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
    """
    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

    messages = [
        {"role": "system", "content": system_message},
        *history,
        {"role": "user", "content": message},
    ]

    response = ""

    # The loop variable is named `chunk` so it does not rebind the `message`
    # parameter while the stream is being consumed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        # Chunks without choices or without delta content contribute no text,
        # but the current response is still yielded for every chunk.
        if choices and choices[0].delta.content:
            response += choices[0].delta.content
        yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
# Extra controls for the chat UI, listed in the same order as the trailing
# parameters of `respond` (system_message, max_tokens, temperature, top_p).
_additional_inputs = [
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

chatbot = gr.ChatInterface(respond, type="messages", additional_inputs=_additional_inputs)

# Page layout: a sidebar holding the login button, with the chat interface in the main area.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()


if __name__ == "__main__":
    demo.launch()
 
1
+ # app.py
2
+ import os
3
+ import json
4
+ import tempfile
5
+ import datetime
6
+ import re
7
+ from typing import List, Dict, Any
8
+
9
  import gradio as gr
10
+ from PIL import Image
11
+ import fitz # PyMuPDF
12
+ import pytesseract
13
+ from pdf2image import convert_from_path
14
+
15
+ import openai
16
+
17
# The OpenAI key is taken from the environment (on Hugging Face Spaces, set it under Secrets).
# A missing or empty key aborts start-up with an explicit error.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY not found in environment. Add it to Secrets in the HF Space.")
openai.api_key = OPENAI_API_KEY

# Model identifiers; each one can be overridden through its environment variable.
LLM_MODEL = os.environ.get("OPENAI_MODEL", "gpt-5")  # change if you use a different model id
EMBEDDING_MODEL = os.environ.get("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")  # optional
26
+
27
+ # ----------------------
28
+ # Text extraction utils
29
+ # ----------------------
30
def extract_text_from_pdf(path: str) -> str:
    """
    Extract the text of a PDF, page by page.

    Each page is first read with PyMuPDF's text extraction. A page that yields
    no text (for example an image-only scan) is rendered at 200 dpi and passed
    through Tesseract OCR instead.

    Args:
        path: Filesystem path of the PDF.

    Returns:
        The per-page texts joined by blank lines, stripped of leading and
        trailing whitespace.
    """
    texts: List[str] = []
    doc = fitz.open(path)
    try:
        for i in range(len(doc)):
            page = doc.load_page(i)
            txt = page.get_text("text").strip()
            if txt:
                texts.append(txt)
                continue

            # Fallback: render the page to a PNG and OCR it.
            pix = page.get_pixmap(dpi=200)
            # mkstemp + close gives a path that is not held open, so the PNG can
            # be written to it on every platform.
            fd, png_path = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            try:
                pix.save(png_path)
                with Image.open(png_path) as img:
                    texts.append(pytesseract.image_to_string(img))
            finally:
                # Remove the rendered page; otherwise one PNG per scanned page
                # would accumulate in the temp directory.
                os.remove(png_path)
    finally:
        doc.close()
    return "\n\n".join(texts).strip()
49
+
50
+
51
def extract_text_from_image(path: str) -> str:
    """
    Run Tesseract OCR on an image file and return the recognised text.

    Args:
        path: Filesystem path of the image.

    Returns:
        The OCR text with leading and trailing whitespace removed.
    """
    # The context manager closes the underlying file handle once the pixels
    # have been converted into an independent RGB image.
    with Image.open(path) as source:
        img = source.convert("RGB")
    return pytesseract.image_to_string(img).strip()
54
+
55
+
56
+ # ----------------------
57
+ # Simple chunker
58
+ # ----------------------
59
def chunk_text(text: str, max_chars: int = 3000) -> List[str]:
    """
    Split text into chunks of at most `max_chars` characters.

    The text is split into paragraphs on blank lines. Consecutive paragraphs
    are packed into one chunk (joined by a blank line) while they fit. A single
    paragraph longer than `max_chars` is cut into fixed-size pieces so that no
    chunk ever exceeds the limit.

    Args:
        text: The text to split.
        max_chars: Maximum length of each chunk; must be positive.

    Returns:
        The chunks in document order. Empty or whitespace-only text gives [].

    Raises:
        ValueError: If `max_chars` is not positive.
    """
    if max_chars <= 0:
        raise ValueError("max_chars must be a positive integer")

    paragraphs = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
    chunks: List[str] = []
    current = ""
    for p in paragraphs:
        if len(p) > max_chars:
            # Oversized paragraph: flush what has been collected, then hard-split.
            if current:
                chunks.append(current)
            pieces = [p[i:i + max_chars] for i in range(0, len(p), max_chars)]
            chunks.extend(pieces[:-1])
            # The final (possibly short) piece stays open so that following
            # paragraphs can still be packed with it.
            current = pieces[-1]
        elif not current:
            current = p
        elif len(current) + len(p) + 2 <= max_chars:
            # +2 accounts for the "\n\n" separator.
            current = current + "\n\n" + p
        else:
            chunks.append(current)
            current = p
    if current:
        chunks.append(current)
    return chunks
73
+
74
+
75
+ # ----------------------
76
+ # LLM call (strict JSON output requested)
77
+ # ----------------------
78
def _extract_json_object(text: str) -> Any:
    """Parse the outermost ``{...}`` span found in `text` (or all of `text` if there is none)."""
    start = text.find("{")
    end = text.rfind("}")
    # Slicing from the first "{" to the last "}" tolerates code fences or stray
    # prose before and after the JSON object.
    candidate = text[start:end + 1] if 0 <= start < end else text
    return json.loads(candidate)


def call_gpt5_for_metadata(title: str, short_text: str, top_chunks: List[str]) -> Dict[str, Any]:
    """
    Ask the configured LLM for a JSON metadata object describing a document.

    The prompt contains the title, a short text excerpt and up to six content
    chunks (each truncated to 800 characters with newlines flattened to
    spaces), and instructs the model to output machine-parseable JSON only.

    Args:
        title: Document title shown to the model.
        short_text: Leading excerpt of the document text.
        top_chunks: Representative chunks; only the first six are used.

    Returns:
        The parsed JSON object. If the reply cannot be parsed into a JSON
        object, returns ``{"_parsing_error": True, "raw_output": <reply text>}``
        so the UI can show the raw output.
    """
    # Build prompt
    prompt = (
        "You are an automated document taxonomy and tagging assistant for enterprise catalogs.\n\n"
        f"Document title: {title}\n\n"
        f"Short document text (first ~1000 chars): {short_text}\n\n"
        "Top content chunks (short):\n"
    )
    for i, c in enumerate(top_chunks[:6], start=1):
        # Flatten newlines outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12, and the escaped '\\n'
        # used previously matched a literal backslash-n rather than a newline.
        flat = c[:800].replace("\n", " ")
        prompt += f"CHUNK_{i}: {flat}\n\n"

    prompt += (
        "Task: Produce a single JSON object (machine parseable) with EXACT keys:\n"
        "doc_id, title, summary, doc_type, source, tags (array of strings), tag_confidences (map tag->float), "
        "taxonomy_path (array of strings), extracted_entities (map), raw_url, ingest_timestamp\n\n"
        "Guidelines:\n"
        "- summary: 1-2 sentences summarizing the doc.\n"
        "- doc_type: short enum-like string (e.g., architecture_comparison, whitepaper, design_doc)\n"
        "- tags: up to 8 short tags like arch:docai, topic:ocr-parsing\n"
        "- tag_confidences: map with floats 0-1 for each tag\n"
        "- taxonomy_path: hierarchical list, e.g. [\"Technology\",\"Document Processing\",\"OCR & Parsing\"]\n"
        "- extracted_entities: map with keys like platforms, tools (each is an array)\n"
        "- raw_url: if not available, return an empty string\n"
        "- ingest_timestamp: ISO8601 with timezone (e.g., 2025-09-19T09:13:00+05:30)\n\n"
        "OUTPUT: ONLY THE JSON OBJECT. DO NOT PROVIDE ANY ADDITIONAL TEXT.\n"
    )

    # NOTE(review): `openai.ChatCompletion` is the pre-1.0 client interface;
    # confirm the pinned `openai` version still provides it.
    response = openai.ChatCompletion.create(
        model=LLM_MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
        max_tokens=1000,
    )

    # Guard against a null content field so the parsing fallback below is reached.
    text = (response["choices"][0]["message"]["content"] or "").strip()

    try:
        data = _extract_json_object(text)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; return an error structure so
        # the UI can show the raw output.
        return {"_parsing_error": True, "raw_output": text}

    if not isinstance(data, dict):
        # Valid JSON that is not an object (list, string, number) does not fit
        # the requested schema and would break callers that use dict methods.
        return {"_parsing_error": True, "raw_output": text}
    return data
128
+
129
+
130
+ # ----------------------
131
+ # Main processing function
132
+ # ----------------------
133
def process_file(file_obj) -> Dict[str, Any]:
    """
    Extract text from an uploaded PDF or image and return LLM-generated metadata.

    Args:
        file_obj: The upload handed over by Gradio. Accepted forms are a path
            (``str`` or ``os.PathLike``), a file-like object exposing ``.name``
            and ``.read()``, or ``None`` when nothing was uploaded.
            NOTE(review): which form Gradio passes depends on the installed
            version; confirm against the `gr.File` documentation.

    Returns:
        The metadata dict ready to display, or a dict with an ``"error"`` key
        (plus ``"raw_output"`` when the LLM reply could not be parsed).
        The document's base file name, not its temporary upload path, is used
        for the title sent to the LLM and for the default ``doc_id``/``title``.
    """
    if file_obj is None:
        return {"error": "No file uploaded."}

    if isinstance(file_obj, (str, os.PathLike)):
        source_path = os.fspath(file_obj)
    else:
        source_path = str(getattr(file_obj, "name", "") or "")
    display_name = os.path.basename(source_path) or source_path

    # A file-like upload is copied to a temporary path so the extractors can
    # open it by name; a plain path is used directly.
    temp_copy = None
    if hasattr(file_obj, "read"):
        suffix = os.path.splitext(display_name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(file_obj.read())
            temp_copy = tmp.name
    work_path = temp_copy or source_path

    # Extract text
    try:
        if display_name.lower().endswith(".pdf"):
            extracted_text = extract_text_from_pdf(work_path)
        else:
            extracted_text = extract_text_from_image(work_path)
    except Exception as e:
        return {"error": f"Text extraction failed: {e}"}
    finally:
        if temp_copy is not None:
            try:
                os.remove(temp_copy)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not fail the request.
                pass

    if not extracted_text:
        return {"error": "No text found in document after extraction."}

    # Chunk and pick top chunks
    chunks = chunk_text(extracted_text)
    # Heuristic: pick longest chunks as representative
    top_chunks = sorted(chunks, key=len, reverse=True)[:6] or [extracted_text[:2000]]

    # Prepare a "short_text" to feed to the LLM
    short_text = (extracted_text[:1000] + "...") if len(extracted_text) > 1000 else extracted_text

    # Call LLM
    metadata = call_gpt5_for_metadata(display_name, short_text, top_chunks)

    # If LLM returned a parsing error, include it
    if metadata.get("_parsing_error"):
        return {
            "error": "LLM output parsing failed. See raw_output.",
            "raw_output": metadata.get("raw_output"),
        }

    # Ensure required keys exist and post-process small things
    now = datetime.datetime.now(datetime.timezone.utc).astimezone().isoformat()
    metadata.setdefault("doc_id", os.path.splitext(display_name)[0])
    metadata.setdefault("title", display_name)
    metadata.setdefault("source", "user_upload")
    metadata.setdefault("raw_url", "")
    metadata.setdefault("ingest_timestamp", now)

    return metadata
183
+
184
+
185
+ # ----------------------
186
+ # Gradio UI
187
+ # ----------------------
188
with gr.Blocks(title="DocClassify — Gradio GPT-5 Taxonomy & Tagging") as demo:
    gr.Markdown("## 📂 Upload a PDF or Image — the app will classify, tag, and propose a taxonomy using GPT-5")
    with gr.Row():
        with gr.Column(scale=1):
            uploader = gr.File(label="Upload PDF / Image", file_types=[".pdf", ".png", ".jpg", ".jpeg", ".tiff"])
            run_button = gr.Button("Process document")
            status = gr.Textbox(label="Status", value="", interactive=False)
            # Hidden until a metadata file has been produced.
            download_button = gr.File(label="Download metadata JSON", visible=False)
        with gr.Column(scale=1):
            output_json = gr.JSON(label="Document metadata (JSON)")

    def on_process(file_obj):
        """
        Click handler: process the upload and return updates for
        (output_json, status, download_button), in that order.

        The UI is changed only through the returned updates; assigning to
        `status.value` inside a handler does not reach the browser, so the
        status text is carried in the return value instead.
        """
        try:
            result = process_file(file_obj)
        except Exception as e:
            return (
                gr.update(value={}),
                gr.update(value="Failed: " + str(e)),
                gr.update(value=None, visible=False),
            )

        if result.get("error"):
            # if raw_output provided, show under JSON
            return (
                gr.update(value={"error": result.get("error"), "raw_output": result.get("raw_output", "")}),
                gr.update(value=f"Error: {result.get('error')}"),
                gr.update(value=None, visible=False),
            )

        # Write the metadata to a temp JSON file for download. The handle is
        # closed on leaving the `with`; the file itself is kept (delete=False)
        # so it can be served.
        with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf8") as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
            json_path = f.name

        # Reveal the download component now that there is a file to offer.
        return (
            gr.update(value=result),
            gr.update(value="Done"),
            gr.update(value=json_path, visible=True),
        )

    run_button.click(on_process, inputs=[uploader], outputs=[output_json, status, download_button])


if __name__ == "__main__":
    demo.launch()