Spaces:
Sleeping
Sleeping
| # app.py — MCP server (single-file) | |
| from mcp.server.fastmcp import FastMCP | |
| from typing import Optional, List, Tuple, Any, Dict | |
| import requests | |
| import os | |
| import gradio as gr | |
| import json | |
| import time | |
| import re | |
| import logging | |
| import asyncio | |
| import gc | |
| import shutil | |
| # --- Import OCR Engine & Prompts --- | |
try:
    from ocr_engine import extract_text_from_file
    from prompts import get_ocr_extraction_prompt, get_agent_prompt
except ImportError as _import_err:
    # The stubs below keep this module importable without the helper
    # modules, but they make OCR return nothing. Report the fallback so a
    # missing dependency is not mistaken for an empty document later on.
    logging.getLogger("mcp_server").warning(
        "ocr_engine/prompts import failed (%s); using no-op fallbacks.",
        _import_err,
    )

    def extract_text_from_file(path):
        """Fallback OCR stub: always return an empty string."""
        return ""

    def get_ocr_extraction_prompt(txt):
        """Fallback extraction prompt: return the OCR text unchanged."""
        return txt

    def get_agent_prompt(h, c, u):
        """Fallback agent prompt: return only the user message ``u``."""
        return u
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcp_server")

# --- Load Config ---
# All Zoho credentials/endpoints and the local model name come from config.py.
# Any failure while importing it is fatal for the app.
try:
    from config import (
        CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN, API_BASE,
        INVOICE_API_BASE, ORGANIZATION_ID, LOCAL_MODEL
    )
except Exception as exc:
    # Log the underlying error (missing module, missing name, error raised
    # inside config.py) before exiting; the exit message itself is unchanged.
    logger.exception("Could not load settings from config")
    raise SystemExit("Config missing.") from exc

mcp = FastMCP("ZohoCRMAgent")

# --- Globals ---
# Lazily populated by init_local_model(); None means "not loaded yet".
LLM_PIPELINE = None
TOKENIZER = None
| # --- Helpers --- | |
def extract_json_safely(text: str) -> Optional[Any]:
    """Parse JSON from ``text``, tolerating surrounding non-JSON chatter.

    The whole string is tried first. If that fails, the text is scanned for
    the first ``{`` or ``[`` from which a complete JSON value can be decoded,
    so trailing prose (even prose containing stray braces) does not break
    the extraction.

    Args:
        text: Raw model output or any other string that may embed JSON.

    Returns:
        The decoded JSON value, or ``None`` when ``text`` is not a string or
        contains no decodable JSON object/array.
    """
    if not isinstance(text, str):
        return None
    try:
        return json.loads(text)
    except ValueError:
        # json.JSONDecodeError is a ValueError; fall through to the scan.
        pass

    decoder = json.JSONDecoder()
    for opener in re.finditer(r"[{\[]", text):
        try:
            # raw_decode stops at the end of the first complete value, so
            # whatever follows the JSON is ignored.
            value, _end = decoder.raw_decode(text, opener.start())
        except ValueError:
            continue
        return value
    return None
| def _normalize_local_path_args(args: Any) -> Any: | |
| if not isinstance(args, dict): return args | |
| fp = args.get("file_path") or args.get("path") | |
| if isinstance(fp, str) and fp.startswith("/mnt/data/") and os.path.exists(fp): | |
| args["file_url"] = f"file://{fp}" | |
| return args | |
| # --- Model Loading --- | |
def init_local_model():
    """Load the local text-generation model into the module-level globals.

    Does nothing when ``LLM_PIPELINE`` is already set. On success both
    ``TOKENIZER`` and ``LLM_PIPELINE`` are assigned. On any failure the error
    is logged with its traceback and both globals are left as they were, so
    a half-finished load never leaves a tokenizer without a pipeline.
    """
    global LLM_PIPELINE, TOKENIZER
    if LLM_PIPELINE is not None:
        return
    try:
        # Imported here so a missing or broken transformers install is
        # reported through the same error path as a failed model load.
        from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

        logger.info(f"Loading lighter model: {LOCAL_MODEL}...")
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL)
        model = AutoModelForCausalLM.from_pretrained(
            LOCAL_MODEL,
            device_map="auto",
            torch_dtype="auto",
        )
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception as e:
        logger.exception("Model load error: %s", e)
        return

    # Publish only once every step has succeeded.
    TOKENIZER = tokenizer
    LLM_PIPELINE = generator
    logger.info("Model loaded.")
def local_llm_generate(prompt: str, max_tokens: int = 512) -> Dict[str, Any]:
    """Run ``prompt`` through the local pipeline and return its completion.

    The model is loaded on first use. The result is always a dict with two
    keys: ``text`` (the generated string, or a message describing why
    nothing was generated) and ``raw`` (the pipeline output, or ``None``
    when the model is unavailable or generation raised).
    """
    if LLM_PIPELINE is None:
        init_local_model()
        if LLM_PIPELINE is None:
            return {"text": "Model not loaded.", "raw": None}

    # Greedy decoding: with do_sample=False, sampling options such as
    # temperature/top_p are deliberately not passed.
    generation_options = {
        "max_new_tokens": max_tokens,
        "return_full_text": False,
        "do_sample": False,
    }
    try:
        outputs = LLM_PIPELINE(prompt, **generation_options)
        generated = outputs[0]["generated_text"] if outputs else ""
    except Exception as e:
        return {"text": f"Error: {e}", "raw": None}
    return {"text": generated, "raw": outputs}
| # --- Tools (Zoho) --- | |
def _get_valid_token_headers() -> dict:
    """Exchange the configured refresh token for an ``Authorization`` header.

    Posts a ``refresh_token`` grant to the Zoho accounts token endpoint.

    Returns:
        ``{"Authorization": "Zoho-oauthtoken <token>"}`` on success, or an
        empty dict when the request fails, the status is not 200, or the
        response carries no usable ``access_token``. Callers treat the empty
        dict as "Auth Failed".
    """
    try:
        r = requests.post("https://accounts.zoho.in/oauth/v2/token", params={
            "refresh_token": REFRESH_TOKEN, "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET, "grant_type": "refresh_token"
        }, timeout=10)
    except requests.RequestException as exc:
        logger.error("Zoho token refresh request failed: %s", exc)
        return {}

    if r.status_code != 200:
        logger.error("Zoho token refresh returned HTTP %s", r.status_code)
        return {}

    try:
        token = r.json().get("access_token")
    except (ValueError, AttributeError):
        # Body was not JSON, or not a JSON object.
        token = None
    if not token:
        # Without this check a 200 response lacking the key would produce
        # the header "Zoho-oauthtoken None" and fail later, less clearly.
        logger.error("Zoho token response did not contain an access_token")
        return {}
    return {"Authorization": f"Zoho-oauthtoken {token}"}
def create_record(module_name: str, record_data: dict) -> str:
    """Create one record in the given Zoho CRM module.

    Args:
        module_name: CRM module path segment appended to ``API_BASE``.
        record_data: Field values for the record; sent as the single element
            of the request's ``data`` list.

    Returns:
        A string for display/parsing by the caller:
        * ``"Auth Failed"`` when no access token could be obtained;
        * ``"Request Failed: ..."`` when the HTTP request itself fails;
        * a JSON string ``{"status": "success", "id": ..., "response": ...}``
          on HTTP 200/201 with the expected body shape;
        * the JSON-encoded body when the shape is unexpected;
        * the raw response text otherwise.
    """
    h = _get_valid_token_headers()
    if not h:
        return "Auth Failed"

    try:
        # Explicit timeout so a stalled connection cannot hang the chat.
        r = requests.post(
            f"{API_BASE}/{module_name}",
            headers=h,
            json={"data": [record_data]},
            timeout=30,
        )
    except requests.RequestException as exc:
        logger.error("create_record request failed: %s", exc)
        return f"Request Failed: {exc}"

    if r.status_code not in (200, 201):
        return r.text

    # Decode the body once; a non-JSON success body is returned verbatim.
    try:
        body = r.json()
    except ValueError:
        return r.text

    try:
        details = body.get("data", [{}])[0].get("details", {})
        record_id = details.get("id")
    except (AttributeError, IndexError, KeyError, TypeError):
        # Unexpected body shape: hand the decoded body back unchanged.
        return json.dumps(body)
    return json.dumps({"status": "success", "id": record_id, "response": body})
def create_invoice(data: dict) -> str:
    """Create an invoice through the Zoho Invoice API.

    Args:
        data: Invoice payload sent as the JSON request body.

    Returns:
        ``"Auth Failed"`` when no access token could be obtained,
        ``"Request Failed: ..."`` when the HTTP request itself fails, the
        JSON-encoded response body on HTTP 200/201, and the raw response
        text otherwise (including a 200/201 whose body is not JSON).
    """
    h = _get_valid_token_headers()
    if not h:
        return "Auth Failed"

    try:
        # Explicit timeout so a stalled connection cannot hang the chat.
        r = requests.post(
            f"{INVOICE_API_BASE}/invoices",
            headers=h,
            params={"organization_id": ORGANIZATION_ID},
            json=data,
            timeout=30,
        )
    except requests.RequestException as exc:
        logger.error("create_invoice request failed: %s", exc)
        return f"Request Failed: {exc}"

    if r.status_code not in (200, 201):
        return r.text
    try:
        return json.dumps(r.json())
    except ValueError:
        return r.text
def process_document(file_path: str, target_module: Optional[str] = "Contacts") -> dict:
    """OCR a local file and ask the local LLM to extract structured data.

    Args:
        file_path: Path of the document on local disk.
        target_module: Accepted for interface compatibility; not used by
            this function.

    Returns:
        ``{"error": ...}`` when the path is missing/invalid, OCR raises, or
        OCR yields no text. Otherwise ``{"status": "success", "file": <base
        name>, "extracted_data": ...}``, where ``extracted_data`` is the JSON
        parsed from the model output or ``{"raw": <model text>}`` when no
        JSON could be parsed.
    """
    # Tool calls may omit file_path entirely; os.path.exists(None) would
    # raise TypeError, so reject non-path values up front.
    path_ok = (
        isinstance(file_path, (str, os.PathLike))
        and bool(file_path)
        and os.path.exists(file_path)
    )
    if not path_ok:
        logger.error(f"process_document: File not found at {file_path}")
        return {"error": f"File not found at path: {file_path}"}

    # 1. OCR
    try:
        raw_text = extract_text_from_file(file_path)
    except Exception as exc:
        # Report an OCR crash in the same error shape as the other failures
        # instead of letting it propagate into the chat handler.
        logger.exception("process_document: OCR failed for %s", file_path)
        return {"error": f"OCR failed: {exc}"}
    if not raw_text:
        return {"error": "OCR empty"}

    # 2. LLM Extraction
    prompt = get_ocr_extraction_prompt(raw_text)
    res = local_llm_generate(prompt, max_tokens=300)
    data = extract_json_safely(res["text"])
    return {
        "status": "success",
        "file": os.path.basename(file_path),
        "extracted_data": data or {"raw": res["text"]},
    }
| # --- Executor --- | |
def _parse_invoice_line_items(raw_items: Any) -> List[dict]:
    """Convert model-supplied line items into ``name``/``rate``/``quantity`` dicts.

    Non-list input yields an empty list and non-dict entries are skipped.
    Raises ``ValueError`` when a rate or quantity cannot be read as a number.
    """
    items = []
    if not isinstance(raw_items, list):
        return items
    for it in raw_items:
        if not isinstance(it, dict):
            continue
        # Accept rates such as "$1,200.50" by dropping the currency sign and
        # thousands separators before conversion.
        rate_text = str(it.get("rate", 0)).replace("$", "").replace(",", "").strip()
        rate = float(rate_text)

        qty_raw = it.get("quantity", 1)
        try:
            quantity = int(qty_raw)
        except (TypeError, ValueError):
            # Handles values such as "2.0"; still raises ValueError for
            # anything that is not numeric at all.
            quantity = int(float(str(qty_raw).strip()))

        items.append({
            "name": it.get("name", "Item"),
            "rate": rate,
            "quantity": quantity,
        })
    return items


def parse_and_execute(model_text: str, history: list) -> str:
    """Extract tool calls from the model output and run them in order.

    ``model_text`` is expected to contain a JSON object, or a list of
    objects, each shaped like ``{"tool": <name>, "args": {...}}``. Supported
    tools are ``create_record``, ``create_invoice`` and ``process_document``.
    The id returned by the most recent successful ``create_record`` is used
    as ``customer_id`` for a later ``create_invoice`` that does not provide
    one.

    Args:
        model_text: Raw text produced by the LLM.
        history: Chat history; accepted for interface compatibility and not
            used here.

    Returns:
        One line per executed command joined with newlines, or
        ``"No valid tool call found."`` when no usable JSON is present.
    """
    payload = extract_json_safely(model_text)
    if not payload:
        return "No valid tool call found."
    if isinstance(payload, dict):
        cmds = [payload]
    elif isinstance(payload, list):
        cmds = payload
    else:
        # A bare JSON scalar (number, string, ...) is not a tool call.
        return "No valid tool call found."

    results = []
    last_contact_id = None
    for cmd in cmds:
        if not isinstance(cmd, dict):
            continue
        tool = cmd.get("tool")
        args = _normalize_local_path_args(cmd.get("args", {}))
        if not isinstance(args, dict):
            # Malformed "args" (e.g. a string): treat as no arguments.
            args = {}

        if tool == "create_record":
            # NOTE(review): the whole args dict (including "module") is sent
            # as the record data - confirm this matches the prompt schema.
            res = create_record(args.get("module", "Contacts"), args)
            results.append(f"Record: {res}")
            try:
                rj = json.loads(res)
            except (TypeError, ValueError):
                rj = None  # Non-JSON reply such as "Auth Failed".
            # Only remember real ids so that a failed call does not wipe out
            # the id from an earlier successful one.
            if isinstance(rj, dict) and rj.get("id"):
                last_contact_id = rj["id"]

        elif tool == "create_invoice":
            customer_id = args.get("customer_id")
            if not customer_id and last_contact_id:
                customer_id = last_contact_id
            try:
                items = _parse_invoice_line_items(args.get("line_items", []))
            except ValueError as exc:
                # Do not send an invoice built from unparseable amounts.
                results.append(f"Invoice: Invalid line items ({exc})")
                continue
            invoice = {"customer_id": customer_id, "line_items": items}
            if args.get("currency"):
                invoice["currency_code"] = args["currency"]
            res = create_invoice(invoice)
            results.append(f"Invoice: {res}")

        elif tool == "process_document":
            # NOTE: Prompts try to prevent this, but if it happens, we rely
            # on args being correct.
            res = process_document(args.get("file_path"))
            results.append(f"Processed: {res}")

        else:
            results.append(f"Unknown tool: {tool}")

    return "\n".join(results)
| # --- Chat Core --- | |
def _format_history(history: Optional[list]) -> str:
    """Render chat history as alternating ``U:`` / ``A:`` lines.

    Accepts ``(user, assistant)`` pairs as well as dict entries carrying
    ``role`` and ``content`` keys. Entries in any other shape are skipped.
    """
    lines = []
    for entry in history or []:
        if isinstance(entry, dict):
            role = entry.get("role")
            if role == "user":
                lines.append(f"U: {entry.get('content')}")
            elif role == "assistant":
                lines.append(f"A: {entry.get('content')}")
        elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
            lines.append(f"U: {entry[0]}\nA: {entry[1]}")
    return "\n".join(lines)


def chat_logic(message: str, file_path: str, history: list) -> str:
    """Produce the assistant reply for one chat turn.

    Steps:
      1. If a file is attached, run it through ``process_document`` and use
         the extracted data as context; a failed extraction ends the turn
         with an ``OCR Failed`` message.
      2. Build the agent prompt from history, file context and the message.
      3. Generate with the local LLM; when the output contains JSON it is
         executed as tool calls, otherwise the text is returned as-is.

    Args:
        message: User text; replaced by a default instruction when empty and
            a file was processed successfully.
        file_path: Path of an uploaded file, or a falsy value when none.
        history: Prior turns, as pairs or role/content dicts.

    Returns:
        The tool execution summary, the raw model text, or an error string.
    """
    # 1. Ingest File IMMEDIATELY
    file_context = ""
    if file_path:
        logger.info(f"Ingesting file from path: {file_path}")
        doc = process_document(file_path)
        if doc.get("status") != "success":
            return f"OCR Failed: {doc}"
        file_context = json.dumps(doc["extracted_data"])
        if not message:
            message = "Create records from this file."

    # 2. Decision Prompt (With context injected)
    hist_txt = _format_history(history)
    prompt = get_agent_prompt(hist_txt, file_context, message)

    # 3. Gen & Execute
    gen = local_llm_generate(prompt, max_tokens=200)
    logger.info(f"LLM Decision: {gen['text']}")
    if extract_json_safely(gen["text"]):
        return parse_and_execute(gen["text"], history)
    return gen["text"]
| # --- UI --- | |
def chat_handler(msg, hist):
    """Gradio chat callback: unpack the UI message and delegate to chat_logic.

    Args:
        msg: The submitted message. Normally a dict with ``text`` and
            ``files`` keys (multimodal input); a plain string is also
            accepted and treated as text with no files.
        hist: Chat history supplied by the UI, forwarded unchanged.

    Returns:
        The reply string to display.
    """
    if isinstance(msg, dict):
        txt = msg.get("text") or ""
        files = msg.get("files") or []
    else:
        txt = msg if isinstance(msg, str) else ""
        files = []

    path = files[0] if files else None
    # NOTE(review): file entries are assumed to be either path strings or
    # dicts with a "path" key - confirm against the installed Gradio version.
    if isinstance(path, dict):
        path = path.get("path")
    if path:
        logger.info(f"UI received file: {path}")

    # Direct path bypass for debugging. The path is normalised first so that
    # ".." segments cannot be used to reach files outside /mnt/data.
    if not path and txt.startswith("/mnt/data"):
        normalized = os.path.normpath(txt)
        if normalized == "/mnt/data" or normalized.startswith("/mnt/data/"):
            return str(process_document(normalized))

    return chat_logic(txt, path, hist)
if __name__ == "__main__":
    # Run a garbage-collection pass before starting the UI.
    gc.collect()
    # Multimodal chat UI: each submission reaches chat_handler as text plus
    # any attached files.
    demo = gr.ChatInterface(multimodal=True, fn=chat_handler)
    # Listen on all interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")