redhairedshanks1 committed on
Commit
22d9309
·
1 Parent(s): 5f78432

Update services/masterllm.py

Browse files
Files changed (1) hide show
  1. services/masterllm.py +286 -286
services/masterllm.py CHANGED
@@ -1,287 +1,287 @@
1
- # # services/masterllm.py
2
- # import json
3
- # import requests
4
- # import os
5
- # import re
6
-
7
- # # Required: set MISTRAL_API_KEY in the environment
8
- # MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
9
- # if not MISTRAL_API_KEY:
10
- # raise RuntimeError("Missing MISTRAL_API_KEY environment variable.")
11
-
12
- # MISTRAL_ENDPOINT = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1/chat/completions")
13
- # MISTRAL_MODEL = os.getenv("MISTRAL_MODEL", "mistral-small")
14
-
15
- # # Steps we support
16
- # ALLOWED_STEPS = {"text", "table", "describe", "summarize", "ner", "classify", "translate"}
17
-
18
- # def build_prompt(instruction: str) -> str:
19
- # return f"""You are a document‑processing assistant.
20
- # Return exactly one JSON object and nothing else — no markdown, no code fences, no explanation, no extra keys.
21
- # Use only the steps the user asks for in the instruction. Do not add any steps not mentioned.
22
- # Valid steps (dash‑separated): {', '.join(sorted(ALLOWED_STEPS))}
23
- # Output schema:
24
- # {{
25
- # "pipeline": "<dash‑separated‑steps>",
26
- # "tools": {{ /* object or null */ }},
27
- # "start_page": <int>,
28
- # "end_page": <int>,
29
- # "target_lang": <string or null>
30
- # }}
31
- # Instruction:
32
- # \"\"\"{instruction.strip()}\"\"\"
33
- # """
34
-
35
- # def extract_json_block(text: str) -> dict:
36
- # # Grab everything between the first { and last }
37
- # start = text.find("{")
38
- # end = text.rfind("}")
39
- # if start == -1 or end == -1:
40
- # return {"error": "no JSON braces found", "raw": text}
41
- # snippet = text[start:end + 1]
42
- # try:
43
- # return json.loads(snippet)
44
- # except json.JSONDecodeError as e:
45
- # # attempt to fix common "tools": {null} → "tools": {}
46
- # cleaned = re.sub(r'"tools"\s*:\s*\{null\}', '"tools": {}', snippet)
47
- # try:
48
- # return json.loads(cleaned)
49
- # except json.JSONDecodeError:
50
- # return {"error": f"json decode error: {e}", "raw": snippet}
51
-
52
- # def validate_pipeline(cfg: dict) -> dict:
53
- # pipe = cfg.get("pipeline")
54
- # if isinstance(pipe, list):
55
- # pipe = "-".join(pipe)
56
- # cfg["pipeline"] = pipe
57
- # if not isinstance(pipe, str):
58
- # return {"error": "pipeline must be a string"}
59
-
60
- # steps = pipe.split("-")
61
- # bad = [s for s in steps if s not in ALLOWED_STEPS]
62
- # if bad:
63
- # return {"error": f"invalid steps: {bad}"}
64
-
65
- # # translate requires target_lang
66
- # if "translate" in steps and not cfg.get("target_lang"):
67
- # return {"error": "target_lang required for translate"}
68
- # return {"ok": True}
69
-
70
- # def _sanitize_config(cfg: dict) -> dict:
71
- # # Defaults and types
72
- # try:
73
- # sp = int(cfg.get("start_page", 1))
74
- # except Exception:
75
- # sp = 1
76
- # try:
77
- # ep = int(cfg.get("end_page", sp))
78
- # except Exception:
79
- # ep = sp
80
- # if sp < 1:
81
- # sp = 1
82
- # if ep < sp:
83
- # ep = sp
84
- # cfg["start_page"] = sp
85
- # cfg["end_page"] = ep
86
-
87
- # # Ensure tools is an object
88
- # if cfg.get("tools") is None:
89
- # cfg["tools"] = {}
90
-
91
- # # Normalize pipeline separators (commas, spaces → dashes)
92
- # raw_pipe = cfg.get("pipeline", "")
93
- # steps = [s.strip() for s in re.split(r"[,\s\-]+", raw_pipe) if s.strip()]
94
- # # Deduplicate while preserving order
95
- # dedup = []
96
- # for s in steps:
97
- # if s in ALLOWED_STEPS and s not in dedup:
98
- # dedup.append(s)
99
- # cfg["pipeline"] = "-".join(dedup)
100
-
101
- # # Normalize target_lang
102
- # if "target_lang" in cfg and cfg["target_lang"] is not None:
103
- # t = str(cfg["target_lang"]).strip()
104
- # cfg["target_lang"] = t if t else None
105
-
106
- # return cfg
107
-
108
- # def generate_pipeline(instruction: str) -> dict:
109
- # prompt = build_prompt(instruction)
110
- # res = requests.post(
111
- # MISTRAL_ENDPOINT,
112
- # headers={
113
- # "Authorization": f"Bearer {MISTRAL_API_KEY}",
114
- # "Content-Type": "application/json",
115
- # },
116
- # json={
117
- # "model": MISTRAL_MODEL,
118
- # "messages": [{"role": "user", "content": prompt}],
119
- # "temperature": 0.0,
120
- # "max_tokens": 256,
121
- # },
122
- # timeout=60,
123
- # )
124
- # res.raise_for_status()
125
- # content = res.json()["choices"][0]["message"]["content"]
126
-
127
- # parsed = extract_json_block(content)
128
- # if "error" in parsed:
129
- # raise RuntimeError(f"PARSE_ERROR: {parsed['error']}\nRAW_OUTPUT:\n{parsed.get('raw', content)}")
130
-
131
- # # Sanitize and normalize
132
- # parsed = _sanitize_config(parsed)
133
-
134
- # check = validate_pipeline(parsed)
135
- # if "error" in check:
136
- # raise RuntimeError(f"PARSE_ERROR: {check['error']}\nRAW_OUTPUT:\n{content}")
137
-
138
- # return parsed
139
-
140
-
141
- # services/masterllm.py
142
- import json
143
- import os
144
- import re
145
- from typing import Dict, Any, List
146
-
147
- import requests
148
-
149
- # Google Gemini API configuration
150
- # Free tier: 15 RPM, 1M TPM, 1500 RPD for gemini-1.5-flash
151
- GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
152
- GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.0-flash")
153
- GEMINI_ENDPOINT = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent"
154
-
155
- _TOOL_TO_TOKEN = {
156
- "extract_text": "text",
157
- "extract_tables": "table",
158
- "describe_images": "describe",
159
- "summarize_text": "summarize",
160
- "classify_text": "classify",
161
- "extract_entities": "ner",
162
- "translate_text": "translate",
163
- "signature_verification": "signature",
164
- "stamp_detection": "stamp",
165
- }
166
-
167
- _ALLOWED_TOOLS = list(_TOOL_TO_TOKEN.keys())
168
-
169
-
170
- def _invoke_gemini(prompt: str) -> str:
171
- """
172
- Invoke Google Gemini API for pipeline planning.
173
- Free tier: 15 RPM, 1M TPM, 1500 RPD for gemini-1.5-flash
174
- """
175
- if not GEMINI_API_KEY:
176
- raise RuntimeError("Missing GEMINI_API_KEY or GOOGLE_API_KEY environment variable")
177
-
178
- headers = {
179
- "Content-Type": "application/json",
180
- }
181
-
182
- payload = {
183
- "contents": [{
184
- "parts": [{"text": prompt}]
185
- }],
186
- "generationConfig": {
187
- "temperature": 0.0,
188
- "maxOutputTokens": 512,
189
- }
190
- }
191
-
192
- response = requests.post(
193
- f"{GEMINI_ENDPOINT}?key={GEMINI_API_KEY}",
194
- headers=headers,
195
- json=payload,
196
- timeout=60,
197
- )
198
-
199
- if response.status_code != 200:
200
- raise RuntimeError(f"Gemini API error: {response.status_code} - {response.text}")
201
-
202
- result = response.json()
203
-
204
- # Extract text from Gemini response
205
- try:
206
- return result["candidates"][0]["content"]["parts"][0]["text"]
207
- except (KeyError, IndexError) as e:
208
- raise RuntimeError(f"Failed to parse Gemini response: {e}\nResponse: {result}")
209
-
210
-
211
- def generate_pipeline(user_instruction: str) -> Dict[str, Any]:
212
- """
213
- Produce a proposed plan as a compact pipeline string + config.
214
- Output example:
215
- {
216
- "pipeline": "text-table-summarize",
217
- "start_page": 1,
218
- "end_page": 3,
219
- "target_lang": null,
220
- "tools": ["extract_text", "extract_tables", "summarize_text"],
221
- "reason": "..."
222
- }
223
- """
224
- system_prompt = f"""You design a tool execution plan for MasterLLM.
225
- Return STRICT JSON with keys:
226
- - pipeline: string of hyphen-joined steps using tokens: text, table, describe, summarize, classify, ner, translate, signature, stamp
227
- - tools: array of tool names from: {", ".join(_ALLOWED_TOOLS)}
228
- - start_page: integer (default 1)
229
- - end_page: integer (default start_page)
230
- - target_lang: string or null
231
- - reason: short rationale
232
- Extract any page range or language from the user's request.
233
-
234
- User instruction: {user_instruction}
235
-
236
- Return only the JSON object, no markdown or explanation."""
237
-
238
- raw = _invoke_gemini(system_prompt)
239
-
240
- # best-effort JSON extraction
241
- try:
242
- data = json.loads(raw)
243
- except Exception:
244
- match = re.search(r"\{.*\}", raw, re.S)
245
- data = json.loads(match.group(0)) if match else {}
246
-
247
- # Fallbacks / validation
248
- tools: List[str] = data.get("tools") or []
249
- # Map tools -> pipeline tokens
250
- tokens = [_TOOL_TO_TOKEN[t] for t in tools if t in _TOOL_TO_TOKEN]
251
- if not tokens:
252
- # heuristic fallback
253
- text_lower = user_instruction.lower()
254
- if "table" in text_lower:
255
- tokens.append("table")
256
- if any(w in text_lower for w in ["text", "extract", "read", "content"]):
257
- tokens.insert(0, "text")
258
- if any(w in text_lower for w in ["summarize", "summary"]):
259
- tokens.append("summarize")
260
- if any(w in text_lower for w in ["translate", "spanish", "french", "german"]):
261
- tokens.append("translate")
262
- if any(w in text_lower for w in ["classify", "category", "categories"]):
263
- tokens.append("classify")
264
- if any(w in text_lower for w in ["ner", "entity", "entities"]):
265
- tokens.append("ner")
266
- if any(w in text_lower for w in ["image", "figure", "diagram", "photo"]):
267
- tokens.append("describe")
268
- pipeline = "-".join(tokens) if tokens else "text"
269
-
270
- start_page = int(data.get("start_page") or 1)
271
- end_page = int(data.get("end_page") or start_page)
272
- target_lang = data.get("target_lang") if data.get("target_lang") not in ["", "none", None] else None
273
-
274
- # if tools empty but tokens present, infer tools from tokens
275
- if not tools and tokens:
276
- inv = {v: k for k, v in _TOOL_TO_TOKEN.items()}
277
- tools = [inv[t] for t in tokens if t in inv]
278
-
279
- return {
280
- "pipeline": pipeline,
281
- "start_page": start_page,
282
- "end_page": end_page,
283
- "target_lang": target_lang,
284
- "tools": tools,
285
- "reason": data.get("reason") or "Auto-generated plan.",
286
- "raw_instruction": user_instruction,
287
  }
 
1
+ # # services/masterllm.py
2
+ # import json
3
+ # import requests
4
+ # import os
5
+ # import re
6
+
7
+ # # Required: set MISTRAL_API_KEY in the environment
8
+ # MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
9
+ # if not MISTRAL_API_KEY:
10
+ # raise RuntimeError("Missing MISTRAL_API_KEY environment variable.")
11
+
12
+ # MISTRAL_ENDPOINT = os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1/chat/completions")
13
+ # MISTRAL_MODEL = os.getenv("MISTRAL_MODEL", "mistral-small")
14
+
15
+ # # Steps we support
16
+ # ALLOWED_STEPS = {"text", "table", "describe", "summarize", "ner", "classify", "translate"}
17
+
18
+ # def build_prompt(instruction: str) -> str:
19
+ # return f"""You are a document‑processing assistant.
20
+ # Return exactly one JSON object and nothing else — no markdown, no code fences, no explanation, no extra keys.
21
+ # Use only the steps the user asks for in the instruction. Do not add any steps not mentioned.
22
+ # Valid steps (dash‑separated): {', '.join(sorted(ALLOWED_STEPS))}
23
+ # Output schema:
24
+ # {{
25
+ # "pipeline": "<dash‑separated‑steps>",
26
+ # "tools": {{ /* object or null */ }},
27
+ # "start_page": <int>,
28
+ # "end_page": <int>,
29
+ # "target_lang": <string or null>
30
+ # }}
31
+ # Instruction:
32
+ # \"\"\"{instruction.strip()}\"\"\"
33
+ # """
34
+
35
+ # def extract_json_block(text: str) -> dict:
36
+ # # Grab everything between the first { and last }
37
+ # start = text.find("{")
38
+ # end = text.rfind("}")
39
+ # if start == -1 or end == -1:
40
+ # return {"error": "no JSON braces found", "raw": text}
41
+ # snippet = text[start:end + 1]
42
+ # try:
43
+ # return json.loads(snippet)
44
+ # except json.JSONDecodeError as e:
45
+ # # attempt to fix common "tools": {null} → "tools": {}
46
+ # cleaned = re.sub(r'"tools"\s*:\s*\{null\}', '"tools": {}', snippet)
47
+ # try:
48
+ # return json.loads(cleaned)
49
+ # except json.JSONDecodeError:
50
+ # return {"error": f"json decode error: {e}", "raw": snippet}
51
+
52
+ # def validate_pipeline(cfg: dict) -> dict:
53
+ # pipe = cfg.get("pipeline")
54
+ # if isinstance(pipe, list):
55
+ # pipe = "-".join(pipe)
56
+ # cfg["pipeline"] = pipe
57
+ # if not isinstance(pipe, str):
58
+ # return {"error": "pipeline must be a string"}
59
+
60
+ # steps = pipe.split("-")
61
+ # bad = [s for s in steps if s not in ALLOWED_STEPS]
62
+ # if bad:
63
+ # return {"error": f"invalid steps: {bad}"}
64
+
65
+ # # translate requires target_lang
66
+ # if "translate" in steps and not cfg.get("target_lang"):
67
+ # return {"error": "target_lang required for translate"}
68
+ # return {"ok": True}
69
+
70
+ # def _sanitize_config(cfg: dict) -> dict:
71
+ # # Defaults and types
72
+ # try:
73
+ # sp = int(cfg.get("start_page", 1))
74
+ # except Exception:
75
+ # sp = 1
76
+ # try:
77
+ # ep = int(cfg.get("end_page", sp))
78
+ # except Exception:
79
+ # ep = sp
80
+ # if sp < 1:
81
+ # sp = 1
82
+ # if ep < sp:
83
+ # ep = sp
84
+ # cfg["start_page"] = sp
85
+ # cfg["end_page"] = ep
86
+
87
+ # # Ensure tools is an object
88
+ # if cfg.get("tools") is None:
89
+ # cfg["tools"] = {}
90
+
91
+ # # Normalize pipeline separators (commas, spaces → dashes)
92
+ # raw_pipe = cfg.get("pipeline", "")
93
+ # steps = [s.strip() for s in re.split(r"[,\s\-]+", raw_pipe) if s.strip()]
94
+ # # Deduplicate while preserving order
95
+ # dedup = []
96
+ # for s in steps:
97
+ # if s in ALLOWED_STEPS and s not in dedup:
98
+ # dedup.append(s)
99
+ # cfg["pipeline"] = "-".join(dedup)
100
+
101
+ # # Normalize target_lang
102
+ # if "target_lang" in cfg and cfg["target_lang"] is not None:
103
+ # t = str(cfg["target_lang"]).strip()
104
+ # cfg["target_lang"] = t if t else None
105
+
106
+ # return cfg
107
+
108
+ # def generate_pipeline(instruction: str) -> dict:
109
+ # prompt = build_prompt(instruction)
110
+ # res = requests.post(
111
+ # MISTRAL_ENDPOINT,
112
+ # headers={
113
+ # "Authorization": f"Bearer {MISTRAL_API_KEY}",
114
+ # "Content-Type": "application/json",
115
+ # },
116
+ # json={
117
+ # "model": MISTRAL_MODEL,
118
+ # "messages": [{"role": "user", "content": prompt}],
119
+ # "temperature": 0.0,
120
+ # "max_tokens": 256,
121
+ # },
122
+ # timeout=60,
123
+ # )
124
+ # res.raise_for_status()
125
+ # content = res.json()["choices"][0]["message"]["content"]
126
+
127
+ # parsed = extract_json_block(content)
128
+ # if "error" in parsed:
129
+ # raise RuntimeError(f"PARSE_ERROR: {parsed['error']}\nRAW_OUTPUT:\n{parsed.get('raw', content)}")
130
+
131
+ # # Sanitize and normalize
132
+ # parsed = _sanitize_config(parsed)
133
+
134
+ # check = validate_pipeline(parsed)
135
+ # if "error" in check:
136
+ # raise RuntimeError(f"PARSE_ERROR: {check['error']}\nRAW_OUTPUT:\n{content}")
137
+
138
+ # return parsed
139
+
140
+
141
+ # services/masterllm.py
142
+ import json
143
+ import os
144
+ import re
145
+ from typing import Dict, Any, List
146
+
147
+ import requests
148
+
149
# Google Gemini API configuration
# Free tier: 15 RPM, 1M TPM, 1500 RPD for gemini-1.5-flash
# Key comes from GEMINI_API_KEY or GOOGLE_API_KEY; it may be None at import
# time — _invoke_gemini raises at call time when no key is configured.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
# Model name is overridable via env; defaults to gemini-2.0-flash.
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.0-flash")
# generateContent REST endpoint for the chosen model (v1beta API surface).
GEMINI_ENDPOINT = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent"

# Maps concrete tool names (the identifiers offered to the planner LLM) to
# the short tokens used in the hyphen-joined "pipeline" string.
_TOOL_TO_TOKEN = {
    "extract_text": "text",
    "extract_tables": "table",
    "describe_images": "describe",
    "summarize_text": "summarize",
    "classify_text": "classify",
    "extract_entities": "ner",
    "translate_text": "translate",
    "signature_verification": "signature",
    "stamp_detection": "stamp",
}

# Tool names the planner may choose from (the keys of the map above).
_ALLOWED_TOOLS = list(_TOOL_TO_TOKEN.keys())
168
+
169
+
170
def _invoke_gemini(prompt: str) -> str:
    """
    Invoke the Google Gemini generateContent API and return the reply text.

    Args:
        prompt: Full prompt text, sent as a single user content part.

    Returns:
        The text of the first candidate's first part.

    Raises:
        RuntimeError: If no API key is configured, the HTTP call returns a
            non-200 status, or the response body lacks the expected shape.

    Free tier (reference): 15 RPM, 1M TPM, 1500 RPD for gemini-1.5-flash.
    """
    if not GEMINI_API_KEY:
        raise RuntimeError("Missing GEMINI_API_KEY or GOOGLE_API_KEY environment variable")

    headers = {
        "Content-Type": "application/json",
        # Send the key as a header rather than a URL query parameter so it
        # does not leak into access logs, proxies, or error messages.
        "x-goog-api-key": GEMINI_API_KEY,
    }

    payload = {
        "contents": [{
            "parts": [{"text": prompt}]
        }],
        "generationConfig": {
            "temperature": 0.0,  # deterministic planning output
            "maxOutputTokens": 512,
        }
    }

    response = requests.post(
        GEMINI_ENDPOINT,
        headers=headers,
        json=payload,
        timeout=60,
    )

    if response.status_code != 200:
        raise RuntimeError(f"Gemini API error: {response.status_code} - {response.text}")

    result = response.json()

    # Gemini nests the generated text: candidates[0].content.parts[0].text.
    # TypeError is included because any level may be null/non-container.
    try:
        return result["candidates"][0]["content"]["parts"][0]["text"]
    except (KeyError, IndexError, TypeError) as e:
        raise RuntimeError(f"Failed to parse Gemini response: {e}\nResponse: {result}") from e
209
+
210
+
211
def _extract_json(raw: str) -> Dict[str, Any]:
    """Best-effort parse of a JSON object from an LLM reply; {} on failure."""
    try:
        data = json.loads(raw)
    except Exception:
        match = re.search(r"\{.*\}", raw, re.S)
        if not match:
            return {}
        try:
            data = json.loads(match.group(0))
        except Exception:
            # Snippet between braces was still not valid JSON.
            return {}
    # The model may return an array or scalar; only a dict is usable.
    return data if isinstance(data, dict) else {}


def _safe_int(value: Any, default: int) -> int:
    """Coerce *value* to int, falling back to *default* on any failure."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _heuristic_tokens(instruction: str) -> List[str]:
    """Keyword-based fallback plan when the model output is unusable."""
    text_lower = instruction.lower()
    tokens: List[str] = []
    if "table" in text_lower:
        tokens.append("table")
    if any(w in text_lower for w in ["text", "extract", "read", "content"]):
        tokens.insert(0, "text")
    if any(w in text_lower for w in ["summarize", "summary"]):
        tokens.append("summarize")
    if any(w in text_lower for w in ["translate", "spanish", "french", "german"]):
        tokens.append("translate")
    if any(w in text_lower for w in ["classify", "category", "categories"]):
        tokens.append("classify")
    if any(w in text_lower for w in ["ner", "entity", "entities"]):
        tokens.append("ner")
    if any(w in text_lower for w in ["image", "figure", "diagram", "photo"]):
        tokens.append("describe")
    return tokens


def generate_pipeline(user_instruction: str) -> Dict[str, Any]:
    """
    Produce a proposed plan as a compact pipeline string + config.

    Asks Gemini for a strict-JSON plan, then sanitizes it: tool names are
    mapped to pipeline tokens, page numbers are coerced to sane integers
    (1 <= start <= end), and fallbacks (the model's own pipeline string,
    then keyword heuristics) fill in a plan when the output is unusable.

    Args:
        user_instruction: Free-form user request describing what to do.

    Returns:
        Dict with keys: pipeline (hyphen-joined step tokens), start_page,
        end_page, target_lang (str or None), tools (list of tool names),
        reason, raw_instruction.

    Raises:
        RuntimeError: Propagated from _invoke_gemini on API/config errors.

    Output example:
    {
      "pipeline": "text-table-summarize",
      "start_page": 1,
      "end_page": 3,
      "target_lang": null,
      "tools": ["extract_text", "extract_tables", "summarize_text"],
      "reason": "..."
    }
    """
    system_prompt = f"""You design a tool execution plan for MasterLLM.
Return STRICT JSON with keys:
- pipeline: string of hyphen-joined steps using tokens: text, table, describe, summarize, classify, ner, translate, signature, stamp
- tools: array of tool names from: {", ".join(_ALLOWED_TOOLS)}
- start_page: integer (default 1)
- end_page: integer (default start_page)
- target_lang: string or null
- reason: short rationale
Extract any page range or language from the user's request.

User instruction: {user_instruction}

Return only the JSON object, no markdown or explanation."""

    raw = _invoke_gemini(system_prompt)
    data = _extract_json(raw)

    # Prefer the model's tool list; keep only tools we actually know.
    tools_field = data.get("tools")
    if isinstance(tools_field, list):
        tools: List[str] = [t for t in tools_field if t in _TOOL_TO_TOKEN]
    else:
        tools = []
    tokens = [_TOOL_TO_TOKEN[t] for t in tools]

    # Second chance: the model may have filled "pipeline" but not "tools".
    if not tokens and isinstance(data.get("pipeline"), str):
        known = set(_TOOL_TO_TOKEN.values())
        tokens = [s for s in re.split(r"[-,\s]+", data["pipeline"]) if s in known]

    # Last resort: keyword heuristics over the raw instruction.
    if not tokens:
        tokens = _heuristic_tokens(user_instruction)
    pipeline = "-".join(tokens) if tokens else "text"

    # Page range: coerce safely (model may emit non-numeric junk) and
    # enforce 1 <= start_page <= end_page.
    start_page = max(1, _safe_int(data.get("start_page"), 1))
    end_page = max(start_page, _safe_int(data.get("end_page"), start_page))

    # Normalize target_lang: treat "", "none"/"null" (any casing) as absent.
    target_lang = data.get("target_lang")
    if isinstance(target_lang, str):
        target_lang = target_lang.strip()
        if not target_lang or target_lang.lower() in ("none", "null"):
            target_lang = None
    else:
        target_lang = None

    # If tools were empty but tokens were derived, infer tool names back.
    if not tools and tokens:
        inv = {v: k for k, v in _TOOL_TO_TOKEN.items()}
        tools = [inv[t] for t in tokens if t in inv]

    return {
        "pipeline": pipeline,
        "start_page": start_page,
        "end_page": end_page,
        "target_lang": target_lang,
        "tools": tools,
        "reason": data.get("reason") or "Auto-generated plan.",
        "raw_instruction": user_instruction,
    }