"""
LangGraph workflow for the Polymer Datasheet Crawler Agent.
Workflow:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ router β”‚ ── decides search vs upload path
β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜
β”‚
β”Œβ”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ web_search β”‚ β”‚ process_upload β”‚
β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚ β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”
β”‚ llm_parse β”‚ ── calls LLaMA 3.1 to extract properties
β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”
β”‚ store_db β”‚ ── persists to SQLite
β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”
β”‚ finalize β”‚ ── formats output
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
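
Example (illustrative values; assumes the Tavily and LLM credentials used by
web_crawler and llm_parser are configured in the environment):

    from graph import run_search

    result = run_search(manufacturer="BASF", polymer_family="PA66")
    print(result["status"], result["message"])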
"""
from __future__ import annotations

import logging
from typing import Any, Literal

from langgraph.graph import END, StateGraph

from database import DatasheetDB
from llm_parser import parse_datasheet, parse_uploaded_text
from models import AgentState, DatasheetRecord
from web_crawler import search_datasheets, _pick_best_source_url

logger = logging.getLogger(__name__)

# ── Shared DB instance ───────────────────────────────────────────────────────
db = DatasheetDB()

# ── Node functions ───────────────────────────────────────────────────────────
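# Each node receives the current state and returns only the keys it updates;
# LangGraph merges these partial dicts back into the shared workflow state.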
def router_node(state: dict[str, Any]) -> dict[str, Any]:
"""Determine whether we're doing a web search or processing an upload."""
input_mode = state.get("input_mode", "search")
logger.info("Router: mode=%s", input_mode)
return {"input_mode": input_mode}
def web_search_node(state: dict[str, Any]) -> dict[str, Any]:
"""Execute Tavily web search for polymer datasheets."""
manufacturer = state.get("manufacturer", "")
polymer_family = state.get("polymer_family", "")
grade = state.get("grade", "")
logger.info(
"Web search: manufacturer=%s, polymer=%s, grade=%s",
manufacturer, polymer_family, grade,
)
results, raw_content = search_datasheets(
manufacturer=manufacturer,
polymer_family=polymer_family,
grade=grade,
)
# Pick the best non-PDF source URL
source_url = _pick_best_source_url(results) if results else ""
return {
"search_results": results,
"raw_content": raw_content,
"source_url": source_url,
"status": "searched" if raw_content else "no_results",
"message": (
f"Found {len(results)} sources with {len(raw_content)} chars of content."
if raw_content
else "No relevant datasheets found in web search."
),
    }

def process_upload_node(state: dict[str, Any]) -> dict[str, Any]:
"""Process user-uploaded datasheet text."""
uploaded_text = state.get("uploaded_text", "")
if not uploaded_text.strip():
return {
"status": "error",
"message": "No text found in uploaded file.",
}
return {
"raw_content": uploaded_text,
"status": "uploaded",
"message": f"Uploaded text: {len(uploaded_text)} chars ready for parsing.",
    }

def llm_parse_node(state: dict[str, Any]) -> dict[str, Any]:
"""Call LLaMA 3.1 to extract structured properties from raw content."""
raw_content = state.get("raw_content", "")
if not raw_content.strip():
return {
"status": "error",
"message": "No content available for LLM parsing.",
"parsing_errors": ["Empty raw content"],
}
manufacturer = state.get("manufacturer", "")
polymer_family = state.get("polymer_family", "")
grade = state.get("grade", "")
source_url = state.get("source_url", "")
logger.info("LLM parsing %d chars of raw content...", len(raw_content))
record, errors = parse_datasheet(
raw_content=raw_content,
manufacturer=manufacturer,
polymer_family=polymer_family,
grade=grade,
source_url=source_url,
)
if record:
return {
"parsed_datasheet": record.model_dump(),
"parsing_errors": errors,
"status": "parsed",
"message": f"Successfully extracted datasheet for {record.trade_name or record.material_name}.",
}
else:
return {
"parsing_errors": errors,
"status": "parse_failed",
"message": f"Failed to parse datasheet: {'; '.join(errors)}",
        }

def store_db_node(state: dict[str, Any]) -> dict[str, Any]:
"""Store the parsed datasheet in the SQLite database."""
parsed = state.get("parsed_datasheet")
if not parsed:
return {
"status": "error",
"message": "No parsed datasheet to store.",
}
    record = DatasheetRecord(**parsed)
record_id = db.upsert(record)
count = db.count()
return {
"status": "stored",
"message": (
f"Stored datasheet '{record.trade_name or record.material_name}' "
f"(ID: {record_id}). Database now has {count} records."
),
    }

def finalize_node(state: dict[str, Any]) -> dict[str, Any]:
"""Final node β€” consolidates the output message."""
status = state.get("status", "unknown")
message = state.get("message", "")
    if status == "stored":
        return {"status": "success", "message": message}
    elif status in ("error", "parse_failed", "no_results"):
        return {"status": "failed", "message": message}
    else:
        return {"status": status, "message": message}

# ── Conditional edges ────────────────────────────────────────────────────────
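# Each router below returns the name of the next node; its Literal return
# values must match the keys of the mapping passed to add_conditional_edges()
# in build_graph().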
def route_by_mode(state: dict[str, Any]) -> Literal["web_search", "process_upload"]:
"""Route to search or upload branch based on input_mode."""
if state.get("input_mode") == "upload":
return "process_upload"
return "web_search"
def route_after_content(state: dict[str, Any]) -> Literal["llm_parse", "finalize"]:
"""Skip LLM parsing if no content was found."""
status = state.get("status", "")
if status in ("no_results", "error"):
return "finalize"
return "llm_parse"
def route_after_parse(state: dict[str, Any]) -> Literal["store_db", "finalize"]:
"""Skip DB storage if parsing failed."""
if state.get("parsed_datasheet"):
return "store_db"
return "finalize"
# ── Build the graph ──────────────────────────────────────────────────────────
def build_graph() -> Any:
    """Construct and compile the LangGraph workflow.

    ``compile()`` returns a compiled graph rather than a ``StateGraph``,
    so the return type is annotated as ``Any``.
    """
    # Use the AgentState schema (imported above) so LangGraph knows which
    # state keys to track and merge; a bare dict carries no type hints.
    workflow = StateGraph(AgentState)
# Add nodes
workflow.add_node("router", router_node)
workflow.add_node("web_search", web_search_node)
workflow.add_node("process_upload", process_upload_node)
workflow.add_node("llm_parse", llm_parse_node)
workflow.add_node("store_db", store_db_node)
workflow.add_node("finalize", finalize_node)
# Set entry point
workflow.set_entry_point("router")
# Router β†’ search or upload
workflow.add_conditional_edges(
"router",
route_by_mode,
{
"web_search": "web_search",
"process_upload": "process_upload",
},
)
# After content acquisition β†’ parse or finalize
workflow.add_conditional_edges(
"web_search",
route_after_content,
{"llm_parse": "llm_parse", "finalize": "finalize"},
)
workflow.add_conditional_edges(
"process_upload",
route_after_content,
{"llm_parse": "llm_parse", "finalize": "finalize"},
)
# After parsing β†’ store or finalize
workflow.add_conditional_edges(
"llm_parse",
route_after_parse,
{"store_db": "store_db", "finalize": "finalize"},
)
# store_db β†’ finalize β†’ END
workflow.add_edge("store_db", "finalize")
workflow.add_edge("finalize", END)
    return workflow.compile()

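# Debugging tip (assumes a recent langgraph release): the compiled graph can
# be visualized with ``build_graph().get_graph().draw_mermaid()``.
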
# ── Convenience runners ──────────────────────────────────────────────────────
def run_search(
manufacturer: str,
polymer_family: str,
grade: str = "",
) -> dict[str, Any]:
"""Run the full workflow in search mode."""
graph = build_graph()
initial_state = {
"input_mode": "search",
"manufacturer": manufacturer,
"polymer_family": polymer_family,
"grade": grade,
}
    return graph.invoke(initial_state)

def run_upload(uploaded_text: str) -> dict[str, Any]:
"""Run the full workflow in upload mode."""
graph = build_graph()
initial_state = {
"input_mode": "upload",
"uploaded_text": uploaded_text,
}
    return graph.invoke(initial_state)

def search_database(
query: str = "",
manufacturer: str = "",
polymer_family: str = "",
) -> Any:
"""Search the existing database."""
return db.search(
query=query,
manufacturer=manufacturer,
polymer_family=polymer_family,
    )

def get_database_summary() -> Any:
"""Get summary of all records in the database."""
return db.get_summary_dataframe()
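
# ── Manual smoke test ────────────────────────────────────────────────────────
# Minimal end-to-end example. Assumes the Tavily and LLM credentials expected
# by web_crawler / llm_parser are configured in the environment; the
# manufacturer/polymer values below are illustrative only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    outcome = run_search(manufacturer="BASF", polymer_family="PA66")
    print(outcome.get("status"), "-", outcome.get("message"))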