from __future__ import annotations

import json
import os
from typing import Annotated, Any, Literal

import gradio as gr

from app import _log_call_end, _log_call_start, _truncate_for_log

from ._core import _resolve_path
from ._docstrings import autodoc

TOOL_SUMMARY = (
    "Scrape and extract structured data from known URLs using ScrapeGraphAI with "
    "Mistral-only models. Supports single-page extraction, bounded crawl extraction, "
    "multi-URL extraction, rendered markdown, and image-aware extraction."
)

ACTION_CHOICES = [
    "extract",
    "crawl_extract",
    "multi_extract",
    "render_markdown",
    "vision_extract",
]

RENDER_CHOICES = ["auto", "browser", "http"]

# Environment variables that override the default Mistral model names.
TEXT_MODEL_ENV = "SCRAPEGRAPH_TEXT_MODEL"
VISION_MODEL_ENV = "SCRAPEGRAPH_VISION_MODEL"
DEFAULT_TEXT_MODEL = "mistral-small-latest"
DEFAULT_VISION_MODEL = "pixtral-12b-latest"

# Holds the exception raised while importing the optional ScrapeGraphAI stack,
# or None when the imports succeeded. Checked by `_require_scrapegraph`.
_IMPORT_ERROR: Exception | None = None

try:
    from langchain.chat_models import init_chat_model
    from pydantic import BaseModel, Field, create_model
    from scrapegraphai.graphs import SmartScraperGraph, SmartScraperMultiGraph
    from scrapegraphai.graphs.abstract_graph import AbstractGraph
    from scrapegraphai.graphs.base_graph import BaseGraph
    from scrapegraphai.nodes import (
        DescriptionNode,
        FetchNode,
        FetchNodeLevelK,
        GenerateAnswerNodeKLevel,
        GenerateAnswerOmniNode,
        ImageToTextNode,
        ParseNode,
        ParseNodeDepthK,
        RAGNode,
    )
    from scrapegraphai.utils.convert_to_md import convert_to_md
except Exception as exc:  # pragma: no cover - import error path is runtime-only
    # Keep the module importable without the optional dependencies; every name
    # is bound to None so later references do not raise NameError.
    _IMPORT_ERROR = exc
    init_chat_model = None
    BaseModel = None
    Field = None
    create_model = None
    SmartScraperGraph = None
    SmartScraperMultiGraph = None
    AbstractGraph = None
    BaseGraph = None
    DescriptionNode = None
    FetchNode = None
    FetchNodeLevelK = None
    GenerateAnswerNodeKLevel = None
    GenerateAnswerOmniNode = None
    ImageToTextNode = None
    ParseNode = None
    ParseNodeDepthK = None
    RAGNode = None
    convert_to_md = None
else:

    class _LimitedFetchNodeLevelK(FetchNodeLevelK):
        """Fetch node that truncates the fetched documents to ``max_pages``."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # `max_pages` is read from the node config; None means "no cap".
            self.max_pages = None if self.node_config is None else self.node_config.get("max_pages")

        def obtain_content(self, documents, loader_kwargs):
            """Delegate to the parent, then keep at most ``max_pages`` documents."""
            documents = super().obtain_content(documents, loader_kwargs)
            if self.max_pages and len(documents) > self.max_pages:
                return documents[: self.max_pages]
            return documents

    class _BoundedDepthSearchGraph(AbstractGraph):
        """Crawl graph: fetch -> parse -> describe -> RAG -> generate answer."""

        def __init__(self, prompt: str, source: str, config: dict, schema: type[BaseModel] | None = None):
            super().__init__(prompt, config, source, schema)
            self.input_key = "url" if source.startswith("http") else "local_dir"

        def _create_graph(self):
            """Build the node pipeline from values stored in ``self.config``."""
            fetch_node_k = _LimitedFetchNodeLevelK(
                input="url| local_dir",
                output=["docs"],
                node_config={
                    "loader_kwargs": self.config.get("loader_kwargs", {}),
                    "force": self.config.get("force", False),
                    "cut": self.config.get("cut", True),
                    "browser_base": self.config.get("browser_base"),
                    "storage_state": self.config.get("storage_state"),
                    "depth": self.config.get("depth", 1),
                    "only_inside_links": self.config.get("only_inside_links", False),
                    "max_pages": self.config.get("max_pages"),
                },
            )
            parse_node_k = ParseNodeDepthK(
                input="docs",
                output=["docs"],
                node_config={"verbose": self.config.get("verbose", False)},
            )
            description_node = DescriptionNode(
                input="docs",
                output=["docs"],
                node_config={
                    "llm_model": self.llm_model,
                    "verbose": self.config.get("verbose", False),
                    "cache_path": self.config.get("cache_path", False),
                },
            )
            rag_node = RAGNode(
                input="docs",
                output=["vectorial_db"],
                node_config={
                    "llm_model": self.llm_model,
                    "embedder_model": self.config.get("embedder_model", False),
                    "verbose": self.config.get("verbose", False),
                },
            )
            generate_answer_k = GenerateAnswerNodeKLevel(
                input="vectorial_db",
                output=["answer"],
                node_config={
                    "llm_model": self.llm_model,
                    "embedder_model": self.config.get("embedder_model", False),
                    "verbose": self.config.get("verbose", False),
                    "schema": self.schema,
                },
            )
            return BaseGraph(
                nodes=[fetch_node_k, parse_node_k, description_node, rag_node, generate_answer_k],
                edges=[
                    (fetch_node_k, parse_node_k),
                    (parse_node_k, description_node),
                    (description_node, rag_node),
                    (rag_node, generate_answer_k),
                ],
                entry_point=fetch_node_k,
                graph_name=self.__class__.__name__,
            )

        def run(self):
            """Execute the graph and return its ``answer`` (or a fallback string)."""
            inputs = {"user_prompt": self.prompt, self.input_key: self.source}
            self.final_state, self.execution_info = self.graph.execute(inputs)
            return self.final_state.get("answer", "No answer found.")

    class _MistralOmniScraperGraph(AbstractGraph):
        """Single-page graph that adds image descriptions from a Mistral vision model."""

        def __init__(self, prompt: str, source: str, config: dict, schema: type[BaseModel] | None = None):
            # Set before super().__init__ so it is already available should the
            # base constructor build the graph (and thus read it) during init.
            self.max_images = config.get("max_images", 5)
            super().__init__(prompt, config, source, schema)
            self.input_key = "url" if source.startswith("http") else "local_dir"

        def _create_graph(self):
            """Build fetch -> parse -> image-to-text -> generate-answer pipeline."""
            vision_model = init_chat_model(
                model=self.config.get("vision_model", DEFAULT_VISION_MODEL),
                model_provider="mistralai",
                api_key=self.config["llm"]["api_key"],
                temperature=0,
            )
            fetch_node = FetchNode(
                input="url | local_dir",
                output=["doc"],
                node_config={
                    "loader_kwargs": self.config.get("loader_kwargs", {}),
                    "storage_state": self.config.get("storage_state"),
                    "use_soup": self.config.get("use_soup", False),
                    "timeout": self.config.get("timeout", 30),
                },
            )
            parse_node = ParseNode(
                input="doc & (url | local_dir)",
                output=["parsed_doc", "link_urls", "img_urls"],
                node_config={
                    "chunk_size": self.model_token,
                    "parse_urls": True,
                    "llm_model": self.llm_model,
                },
            )
            image_to_text_node = ImageToTextNode(
                input="img_urls",
                output=["img_desc"],
                node_config={
                    "llm_model": vision_model,
                    "max_images": self.max_images,
                },
            )
            generate_answer_omni_node = GenerateAnswerOmniNode(
                input="user_prompt & (relevant_chunks | parsed_doc | doc) & img_desc",
                output=["answer"],
                node_config={
                    "llm_model": self.llm_model,
                    "additional_info": self.config.get("additional_info"),
                    "schema": self.schema,
                },
            )
            return BaseGraph(
                nodes=[fetch_node, parse_node, image_to_text_node, generate_answer_omni_node],
                edges=[
                    (fetch_node, parse_node),
                    (parse_node, image_to_text_node),
                    (image_to_text_node, generate_answer_omni_node),
                ],
                entry_point=fetch_node,
                graph_name=self.__class__.__name__,
            )

        def run(self):
            """Execute the graph and return its ``answer`` (or a fallback string)."""
            inputs = {"user_prompt": self.prompt, self.input_key: self.source}
            self.final_state, self.execution_info = self.graph.execute(inputs)
            return self.final_state.get("answer", "No answer found.")


class ScrapeGraphToolError(RuntimeError):
    """Expected tool failure carrying a machine-readable code and optional hint."""

    def __init__(self, code: str, message: str, hint: str | None = None):
        super().__init__(message)
        self.code = code
        self.message = message
        self.hint = hint


def _json_response(payload: dict[str, Any]) -> str:
    """Serialize *payload* as indented JSON; non-serializable values become ``str``."""
    return json.dumps(payload, ensure_ascii=False, indent=2, default=str)


def _error_response(action: str, code: str, message: str, hint: str | None = None) -> str:
    """Return the JSON error envelope for *action*; ``hint`` is omitted when falsy."""
    error: dict[str, Any] = {"code": code, "message": message}
    if hint:
        error["hint"] = hint
    return _json_response({"action": action, "error": error})


def _require_scrapegraph() -> None:
    """Raise ``ScrapeGraphToolError`` if the optional ScrapeGraphAI imports failed."""
    if _IMPORT_ERROR is not None:
        raise ScrapeGraphToolError(
            "missing_scrapegraph_dependencies",
            f"ScrapeGraphAI dependencies are unavailable: {_IMPORT_ERROR}",
            "Install `scrapegraphai>=1.75.1` and its runtime dependencies.",
        )


def _require_mistral_key() -> str:
    """Return the stripped ``MISTRAL_API_KEY`` or raise ``ScrapeGraphToolError``."""
    api_key = os.getenv("MISTRAL_API_KEY", "").strip()
    if not api_key:
        raise ScrapeGraphToolError(
            "missing_mistral_api_key",
            "MISTRAL_API_KEY is not configured.",
            "Set MISTRAL_API_KEY in the environment before using ScrapeGraphAI extraction actions.",
        )
    return api_key


def _coerce_urls(urls: Any) -> list[str]:
    """Normalize *urls* into a list of non-empty, stripped strings.

    Accepts ``None``/``""`` (empty result), a list, a JSON-array string, or a
    string of URLs separated by commas and/or newlines.

    Raises:
        ScrapeGraphToolError: code ``invalid_urls`` when the value has an
            unsupported type, or looks like a JSON array but does not parse
            to one.
    """
    if urls is None or urls == "":
        return []
    if isinstance(urls, list):
        return [str(url).strip() for url in urls if str(url).strip()]
    if isinstance(urls, str):
        text = urls.strip()
        if not text:
            return []
        if text.startswith("["):
            try:
                parsed = json.loads(text)
            except json.JSONDecodeError as exc:
                # Report malformed JSON as an input error instead of letting it
                # fall through to the caller's generic "fetch_failed" handler.
                raise ScrapeGraphToolError("invalid_urls", f"urls is not a valid JSON array: {exc}") from exc
            if not isinstance(parsed, list):
                raise ScrapeGraphToolError("invalid_urls", "urls must be a JSON array of URL strings.")
            return [str(url).strip() for url in parsed if str(url).strip()]
        # Treat carriage returns and commas as newline separators.
        normalized = text.replace("\r", "\n").replace(",", "\n")
        return [part.strip() for part in normalized.split("\n") if part.strip()]
    raise ScrapeGraphToolError("invalid_urls", "urls must be provided as a list or JSON array string.")


def _coerce_schema(schema_json: Any) -> dict[str, Any] | None:
    """Return *schema_json* as a dict, or ``None`` when it is empty.

    Raises:
        ScrapeGraphToolError: code ``invalid_schema_json`` for invalid JSON
            text, JSON that is not an object, or an unsupported input type.
    """
    if schema_json in (None, "", {}):
        return None
    if isinstance(schema_json, dict):
        return schema_json
    if isinstance(schema_json, str):
        try:
            parsed = json.loads(schema_json)
        except json.JSONDecodeError as exc:
            raise ScrapeGraphToolError("invalid_schema_json", f"schema_json is not valid JSON: {exc}") from exc
        if not isinstance(parsed, dict):
            raise ScrapeGraphToolError("invalid_schema_json", "schema_json must decode to a JSON object.")
        return parsed
    raise ScrapeGraphToolError("invalid_schema_json", "schema_json must be a JSON object or JSON string.")


def _schema_to_type(name: str, schema: dict[str, Any]) -> Any:
    """Map a JSON-schema fragment to a Python type or generated pydantic model.

    Scalars map to ``str``/``int``/``float``/``bool``, arrays to ``list[...]``
    of the item type, and objects to a model built with ``create_model`` named
    *name*. Anything unrecognized maps to ``Any``.
    """
    schema_type = schema.get("type")
    if schema_type == "string":
        return str
    if schema_type == "integer":
        return int
    if schema_type == "number":
        return float
    if schema_type == "boolean":
        return bool
    if schema_type == "array":
        item_schema = schema.get("items", {})
        return list[_schema_to_type(f"{name}Item", item_schema)]
    if schema_type == "object" or "properties" in schema:
        properties = schema.get("properties", {})
        required = set(schema.get("required", []))
        fields: dict[str, tuple[Any, Any]] = {}
        for prop_name, prop_schema in properties.items():
            prop_type = _schema_to_type(f"{name}{prop_name.title()}", prop_schema)
            description = prop_schema.get("description")
            is_required = prop_name in required
            # Optional properties accept None and default to None.
            annotation = prop_type if is_required else (prop_type | None)
            default = Field(... if is_required else None, description=description)
            fields[prop_name] = (annotation, default)
        return create_model(name, **fields)
    return Any


def _schema_to_model(schema: dict[str, Any] | None) -> type[BaseModel] | None:
    """Build a pydantic model from an object-shaped schema, or return ``None``.

    Raises:
        ScrapeGraphToolError: code ``invalid_schema_json`` when the schema is
            not object-shaped or does not yield a ``BaseModel`` subclass.
    """
    if not schema:
        return None
    if schema.get("type") not in (None, "object") and "properties" not in schema:
        raise ScrapeGraphToolError(
            "invalid_schema_json",
            "Only object-shaped JSON schemas are supported for schema_json.",
        )
    model_type = _schema_to_type("ScrapeGraphResult", schema)
    if not isinstance(model_type, type) or not issubclass(model_type, BaseModel):
        raise ScrapeGraphToolError(
            "invalid_schema_json",
            "schema_json must define an object with properties for structured extraction.",
        )
    return model_type


def _resolve_storage_state(storage_state_path: str | None) -> str | None:
    """Resolve a storage-state path and verify that it exists.

    Absolute paths are used as given; relative paths go through
    ``_resolve_path``. Returns ``None`` for an empty or blank input.

    Raises:
        ScrapeGraphToolError: code ``invalid_storage_state_path`` when the
            resolved path does not exist.
    """
    if not storage_state_path:
        return None
    candidate = storage_state_path.strip()
    if not candidate:
        return None
    if os.path.isabs(candidate):
        resolved = candidate
    else:
        resolved, _ = _resolve_path(candidate)
    if not os.path.exists(resolved):
        raise ScrapeGraphToolError(
            "invalid_storage_state_path",
            f"Storage state file not found: {candidate}",
        )
    return resolved


def _build_config(
    *,
    api_key: str | None,
    text_model: str | None = None,
    render_mode: str = "auto",
    timeout_s: int = 30,
    storage_state_path: str | None = None,
    depth: int | None = None,
    max_pages: int | None = None,
    same_domain_only: bool | None = None,
    max_images: int | None = None,
    vision_model: str | None = None,
) -> dict[str, Any]:
    """Assemble the ScrapeGraph config dict from the tool arguments.

    Optional keys are only added when the matching argument is provided.
    Numeric limits are clamped to a minimum (timeout >= 5, the others >= 1).

    Raises:
        ScrapeGraphToolError: code ``invalid_render_mode`` when *render_mode*
            is not one of ``RENDER_CHOICES``.
    """
    if render_mode not in RENDER_CHOICES:
        raise ScrapeGraphToolError("invalid_render_mode", f"Unsupported render_mode: {render_mode}")

    config: dict[str, Any] = {
        "headless": True,
        "verbose": False,
        "timeout": max(5, int(timeout_s)),
        # Only the explicit "http" mode switches to the soup-based fetch.
        "use_soup": render_mode == "http",
    }
    if api_key:
        config["llm"] = {
            "api_key": api_key,
            "model": f"mistralai/{text_model or os.getenv(TEXT_MODEL_ENV, DEFAULT_TEXT_MODEL)}",
            "temperature": 0,
        }
    if storage_state_path:
        config["storage_state"] = storage_state_path
    if depth is not None:
        config["depth"] = max(1, int(depth))
    if max_pages is not None:
        config["max_pages"] = max(1, int(max_pages))
    if same_domain_only is not None:
        config["only_inside_links"] = bool(same_domain_only)
    if max_images is not None:
        config["max_images"] = max(1, int(max_images))
    if vision_model:
        config["vision_model"] = vision_model
    return config


def _json_safe(value: Any) -> Any:
    """Recursively convert *value* into JSON-friendly structures.

    Pydantic models are dumped, dicts/lists/tuples are walked, objects with
    ``metadata`` and ``page_content`` attributes become dicts, and strings that
    look like JSON objects/arrays are parsed when possible.
    """
    if BaseModel is not None and isinstance(value, BaseModel):
        return value.model_dump(mode="json")
    if isinstance(value, dict):
        return {key: _json_safe(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    if hasattr(value, "metadata") and hasattr(value, "page_content"):
        return {
            "page_content": getattr(value, "page_content", ""),
            "metadata": _json_safe(getattr(value, "metadata", {})),
        }
    if isinstance(value, str):
        stripped = value.strip()
        if stripped.startswith("{") or stripped.startswith("["):
            try:
                return json.loads(stripped)
            except Exception:
                # Deliberate best-effort: text that merely looks like JSON is
                # returned unchanged.
                return value
    return value


def _extract_sources(state: dict[str, Any], fallback: list[str] | None = None) -> list[str]:
    """Collect unique source values from ``state["docs"]`` and ``state["doc"]``.

    ``docs`` entries are read as dicts with a ``source`` key; ``doc`` entries
    are read through their ``metadata`` attribute. When nothing is found, the
    truthy entries of *fallback* are returned instead.
    """
    sources: list[str] = []
    for item in state.get("docs", []) or []:
        source = item.get("source") if isinstance(item, dict) else None
        if source and source not in sources:
            sources.append(source)
    for doc in state.get("doc", []) or []:
        metadata = getattr(doc, "metadata", {}) or {}
        source = metadata.get("source")
        if source and source not in sources:
            sources.append(source)
    if not sources and fallback:
        sources.extend([source for source in fallback if source])
    return sources


def _extract_links_and_images(doc_state: dict[str, Any], url: str) -> tuple[list[str], list[str]]:
    """Run a ``ParseNode`` over fetched documents and return (links, images).

    Documents come from ``doc_state["doc"]``, falling back to
    ``doc_state["html_content"]``; two empty lists are returned if neither
    is present.
    """
    parse_node = ParseNode(
        input="doc & url",
        output=["parsed_doc", "link_urls", "img_urls"],
        node_config={
            "parse_urls": True,
            "parse_html": True,
            "chunk_size": 8192,
            "llm_model": None,
        },
    )
    docs = doc_state.get("doc")
    if not docs:
        docs = doc_state.get("html_content", [])
    if not docs:
        return [], []
    state = {"doc": docs, "url": url}
    # The results are read back from the same `state` dict after execution.
    parse_node.execute(state)
    return state.get("link_urls", []) or [], state.get("img_urls", []) or []


def _render_markdown_with_fetch(url: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Fetch *url* with a ``FetchNode`` and convert the first document to markdown.

    Returns the fetch state (with an added ``markdown`` key) and an empty
    execution-info list.

    Raises:
        ScrapeGraphToolError: code ``fetch_failed`` when no document is
            returned or the first document has empty content.
    """
    fetch_node = FetchNode(
        input="url",
        output=["doc"],
        node_config=config,
    )
    state = {"url": url}
    state = fetch_node.execute(state)
    docs = state.get("doc", []) or []
    if not docs:
        raise ScrapeGraphToolError("fetch_failed", "ScrapeGraph fetch returned no documents for render_markdown.")
    html = getattr(docs[0], "page_content", None) or ""
    if not html.strip():
        raise ScrapeGraphToolError("fetch_failed", "Fetched document for render_markdown had empty content.")
    state["markdown"] = convert_to_md(html)
    return state, []


# NOTE(review): no docstring is written here on purpose; `autodoc` is presumably
# responsible for generating it from TOOL_SUMMARY and the Annotated metadata.
# Confirm against `._docstrings` before adding one.
@autodoc(summary=TOOL_SUMMARY)
def ScrapeGraphAI(
    action: Annotated[
        Literal["extract", "crawl_extract", "multi_extract", "render_markdown", "vision_extract"],
        "Action to run: extract, crawl_extract, multi_extract, render_markdown, or vision_extract.",
    ] = "extract",
    url: Annotated[str, "Single URL for extract, crawl_extract, render_markdown, or vision_extract."] = "",
    urls: Annotated[list[str] | str | None, "Explicit list of URLs for multi_extract. Accepts a list or JSON array string."] = None,
    prompt: Annotated[str, "Natural-language extraction prompt. Required for extraction actions."] = "",
    schema_json: Annotated[dict[str, Any] | str | None, "Optional object-shaped JSON schema for structured extraction."] = None,
    render_mode: Annotated[Literal["auto", "browser", "http"], "Fetch mode. `browser` uses ScrapeGraph browser loading, `http` uses requests + soup, `auto` currently follows ScrapeGraph's browser-first path."] = "auto",
    include_images: Annotated[bool, "For `extract`, include page images in the extraction context."] = False,
    depth: Annotated[int, "For `crawl_extract`, crawl depth from the starting URL."] = 1,
    max_pages: Annotated[int, "For `crawl_extract`, soft cap on fetched pages."] = 4,
    same_domain_only: Annotated[bool, "For `crawl_extract`, stay within the starting site's links only."] = True,
    max_urls: Annotated[int, "For `multi_extract`, maximum URLs allowed in one call."] = 8,
    max_images: Annotated[int, "For `vision_extract` and image-aware extraction, maximum images to describe."] = 5,
    max_chars: Annotated[int, "For `render_markdown`, trim returned markdown to this many characters."] = 12000,
    include_links: Annotated[bool, "For `render_markdown`, include discovered page links."] = True,
    timeout_s: Annotated[int, "Timeout in seconds passed to ScrapeGraph fetch and generation nodes."] = 30,
    storage_state_path: Annotated[str, "Optional Playwright storage state JSON path for authenticated pages."] = "",
    return_debug: Annotated[bool, "Include execution metadata and graph execution info in the response."] = False,
) -> str:
    _log_call_start(
        "ScrapeGraphAI",
        action=action,
        url=url,
        urls=urls,
        prompt=_truncate_for_log(prompt or "", 180),
        render_mode=render_mode,
        include_images=include_images,
        depth=depth,
        max_pages=max_pages,
        max_urls=max_urls,
        max_images=max_images,
        timeout_s=timeout_s,
        storage_state_path=storage_state_path,
        return_debug=return_debug,
    )
    # Normalize once; `or ""` tolerates callers that pass None for these fields.
    target_url = (url or "").strip()
    prompt_text = (prompt or "").strip()
    try:
        _require_scrapegraph()
        storage_state = _resolve_storage_state(storage_state_path)
        schema = _coerce_schema(schema_json)
        schema_model = _schema_to_model(schema)
        text_model_name = os.getenv(TEXT_MODEL_ENV, DEFAULT_TEXT_MODEL)
        vision_model_name = os.getenv(VISION_MODEL_ENV, DEFAULT_VISION_MODEL)

        # render_markdown is handled before the API-key check because it
        # builds its config with api_key=None (no LLM section).
        if action == "render_markdown":
            if not target_url:
                raise ScrapeGraphToolError("missing_url", "url is required for render_markdown.")
            final_state, exec_info = _render_markdown_with_fetch(
                target_url,
                _build_config(
                    api_key=None,
                    render_mode=render_mode,
                    timeout_s=timeout_s,
                    storage_state_path=storage_state,
                ),
            )
            # The character limit is clamped to at least 1000.
            markdown = (final_state.get("markdown") or "")[: max(1000, int(max_chars))]
            links, images = _extract_links_and_images(final_state, target_url)
            response = {
                "action": action,
                "result": {"markdown": markdown},
                "sources": [target_url],
                "artifacts": {
                    "markdown": markdown,
                    "links": links if include_links else [],
                    "images": images if include_images else [],
                    "per_url_results": [],
                },
                "meta": {
                    "render_mode_used": render_mode,
                    "text_model": None,
                    "vision_model": None,
                },
                "warnings": [],
            }
            if return_debug:
                response["debug"] = {"final_state": _json_safe(final_state), "execution_info": _json_safe(exec_info)}
            result = _json_response(response)
            _log_call_end("ScrapeGraphAI", _truncate_for_log(result))
            return result

        api_key = _require_mistral_key()

        if action == "extract":
            if not target_url or not prompt_text:
                raise ScrapeGraphToolError("missing_arguments", "url and prompt are required for extract.")
            config = _build_config(
                api_key=api_key,
                text_model=text_model_name,
                render_mode=render_mode,
                timeout_s=timeout_s,
                storage_state_path=storage_state,
                max_images=max_images,
                vision_model=vision_model_name,
            )
            # Image-aware extraction swaps in the omni graph.
            graph_cls = _MistralOmniScraperGraph if include_images else SmartScraperGraph
            graph = graph_cls(prompt=prompt_text, source=target_url, config=config, schema=schema_model)
            result_data = _json_safe(graph.run())
            final_state = graph.get_state()
            response = {
                "action": action,
                "result": result_data,
                "sources": _extract_sources(final_state, [target_url]),
                "artifacts": {
                    "markdown": None,
                    "links": final_state.get("link_urls", []) or [],
                    "images": final_state.get("img_urls", []) or [],
                    "per_url_results": [],
                },
                "meta": {
                    "render_mode_used": render_mode,
                    "text_model": text_model_name,
                    "vision_model": vision_model_name if include_images else None,
                },
                "warnings": [],
            }
            if return_debug:
                response["debug"] = {"final_state": _json_safe(final_state), "execution_info": _json_safe(graph.get_execution_info())}
            result = _json_response(response)
            _log_call_end("ScrapeGraphAI", _truncate_for_log(result))
            return result

        if action == "vision_extract":
            if not target_url or not prompt_text:
                raise ScrapeGraphToolError("missing_arguments", "url and prompt are required for vision_extract.")
            graph = _MistralOmniScraperGraph(
                prompt=prompt_text,
                source=target_url,
                config=_build_config(
                    api_key=api_key,
                    text_model=text_model_name,
                    render_mode=render_mode,
                    timeout_s=timeout_s,
                    storage_state_path=storage_state,
                    max_images=max_images,
                    vision_model=vision_model_name,
                ),
                schema=schema_model,
            )
            result_data = _json_safe(graph.run())
            final_state = graph.get_state()
            img_urls = final_state.get("img_urls", []) or []
            if not img_urls:
                raise ScrapeGraphToolError("no_images_found", "No images were found on the page for vision_extract.")
            response = {
                "action": action,
                "result": result_data,
                "sources": _extract_sources(final_state, [target_url]),
                "artifacts": {
                    "markdown": None,
                    "links": final_state.get("link_urls", []) or [],
                    "images": img_urls,
                    "per_url_results": [],
                },
                "meta": {
                    "render_mode_used": render_mode,
                    "text_model": text_model_name,
                    "vision_model": vision_model_name,
                },
                "warnings": [],
            }
            if return_debug:
                response["debug"] = {"final_state": _json_safe(final_state), "execution_info": _json_safe(graph.get_execution_info())}
            result = _json_response(response)
            _log_call_end("ScrapeGraphAI", _truncate_for_log(result))
            return result

        if action == "multi_extract":
            normalized_urls = _coerce_urls(urls)
            if not normalized_urls or not prompt_text:
                raise ScrapeGraphToolError("missing_arguments", "urls and prompt are required for multi_extract.")
            # Report the limit that is actually enforced (clamped to >= 1).
            url_limit = max(1, int(max_urls))
            if len(normalized_urls) > url_limit:
                raise ScrapeGraphToolError("too_many_urls", f"multi_extract supports at most {url_limit} URLs per call.")
            graph = SmartScraperMultiGraph(
                prompt=prompt_text,
                source=normalized_urls,
                config=_build_config(
                    api_key=api_key,
                    text_model=text_model_name,
                    render_mode=render_mode,
                    timeout_s=timeout_s,
                    storage_state_path=storage_state,
                ),
                schema=schema_model,
            )
            result_data = _json_safe(graph.run())
            final_state = graph.get_state()
            response = {
                "action": action,
                "result": result_data,
                "sources": normalized_urls,
                "artifacts": {
                    "markdown": None,
                    "links": [],
                    "images": [],
                    "per_url_results": _json_safe(final_state.get("results", [])),
                },
                "meta": {
                    "render_mode_used": render_mode,
                    "text_model": text_model_name,
                    "vision_model": None,
                },
                "warnings": [],
            }
            if return_debug:
                response["debug"] = {"final_state": _json_safe(final_state), "execution_info": _json_safe(graph.get_execution_info())}
            result = _json_response(response)
            _log_call_end("ScrapeGraphAI", _truncate_for_log(result))
            return result

        if action == "crawl_extract":
            if not target_url or not prompt_text:
                raise ScrapeGraphToolError("missing_arguments", "url and prompt are required for crawl_extract.")
            graph = _BoundedDepthSearchGraph(
                prompt=prompt_text,
                source=target_url,
                config=_build_config(
                    api_key=api_key,
                    text_model=text_model_name,
                    render_mode=render_mode,
                    timeout_s=timeout_s,
                    storage_state_path=storage_state,
                    depth=depth,
                    max_pages=max_pages,
                    same_domain_only=same_domain_only,
                ),
                schema=schema_model,
            )
            result_data = _json_safe(graph.run())
            final_state = graph.get_state()
            response = {
                "action": action,
                "result": result_data,
                "sources": _extract_sources(final_state, [target_url]),
                "artifacts": {
                    "markdown": None,
                    "links": [],
                    "images": [],
                    "per_url_results": [],
                },
                "meta": {
                    "render_mode_used": render_mode,
                    "text_model": text_model_name,
                    "vision_model": None,
                },
                "warnings": [],
            }
            if return_debug:
                response["debug"] = {"final_state": _json_safe(final_state), "execution_info": _json_safe(graph.get_execution_info())}
            result = _json_response(response)
            _log_call_end("ScrapeGraphAI", _truncate_for_log(result))
            return result

        raise ScrapeGraphToolError("unsupported_action", f"Unsupported action: {action}")
    except ScrapeGraphToolError as exc:
        # Expected failures are returned as a JSON error envelope, not raised.
        result = _error_response(action, exc.code, exc.message, exc.hint)
        _log_call_end("ScrapeGraphAI", _truncate_for_log(result))
        return result
    except Exception as exc:  # pragma: no cover - runtime integration path
        # Classify by message text: mentions of playwright/chromium are
        # reported as a missing browser, everything else as a fetch failure.
        lowered = str(exc).lower()
        code = "browser_unavailable" if "playwright" in lowered or "chromium" in lowered else "fetch_failed"
        result = _error_response(action, code, f"ScrapeGraphAI action failed: {exc}")
        _log_call_end("ScrapeGraphAI", _truncate_for_log(result))
        return result


def build_interface() -> gr.Interface:
    """Build the Gradio interface; inputs are listed in ``ScrapeGraphAI`` parameter order."""
    return gr.Interface(
        fn=ScrapeGraphAI,
        inputs=[
            gr.Dropdown(choices=ACTION_CHOICES, value="extract", label="Action"),
            gr.Textbox(label="URL", placeholder="https://example.com"),
            gr.JSON(label="URLs", value=[]),
            gr.Textbox(label="Prompt", lines=4, placeholder="Extract pricing tiers and main limits."),
            gr.JSON(label="Schema JSON", value={}),
            gr.Dropdown(choices=RENDER_CHOICES, value="auto", label="Render Mode"),
            gr.Checkbox(label="Include Images", value=False),
            gr.Number(label="Depth", value=1, precision=0),
            gr.Number(label="Max Pages", value=4, precision=0),
            gr.Checkbox(label="Same Domain Only", value=True),
            gr.Number(label="Max URLs", value=8, precision=0),
            gr.Number(label="Max Images", value=5, precision=0),
            gr.Number(label="Max Chars", value=12000, precision=0),
            gr.Checkbox(label="Include Links", value=True),
            gr.Number(label="Timeout (seconds)", value=30, precision=0),
            gr.Textbox(label="Storage State Path", placeholder="Optional Playwright storage_state JSON path"),
            gr.Checkbox(label="Return Debug", value=False),
        ],
        outputs=gr.Textbox(label="Result", lines=20, max_lines=40),
        title="ScrapeGraphAI",
        # Newlines are written as escapes; a plain "..." literal cannot span
        # physical lines.
        description="\nMistral-only structured scraping using ScrapeGraphAI graphs.\n",
        api_description=TOOL_SUMMARY,
        flagging_mode="never",
    )


__all__ = ["ScrapeGraphAI", "build_interface"]