Spaces:
Sleeping
Sleeping
| """ | |
| Lazy-loading inference wrappers for the 3 trained LoRA agents. | |
| Each model is loaded on first call and cached for the session. | |
| Falls back gracefully to None if GPU unavailable or models fail to load. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import re | |
| import threading | |
| from typing import Optional | |
| # HF Hub model IDs | |
| EXTRACTOR_HUB = "ps2181/extractor-lora-qwen2.5-1.5b" | |
| AUDITOR_HUB = "ps2181/auditor-lora-qwen2.5-1.5b" | |
| GENERATOR_HUB = "ps2181/generator-lora-qwen2.5-1.5b" | |
| BASE_MODEL = "unsloth/Qwen2.5-1.5B-Instruct" | |
| EXTRACTOR_SYSTEM = ( | |
| "Extract invoice fields. Return JSON only: " | |
| "{vendor, date (YYYY-MM-DD), currency (USD/EUR/GBP), total (float), " | |
| "line_items [{description, qty, unit_price, amount}]}" | |
| ) | |
| AUDITOR_SYSTEM = ( | |
| "You are an invoice fraud auditor. Review each invoice and output a JSON array.\n" | |
| "For each invoice output: " | |
| "{\"invoice_id\": \"INV-XXXXX\", \"verdict\": \"approved\" or \"flagged\", " | |
| "\"fraud_type\": null or one of [\"phantom_vendor\",\"price_gouging\"," | |
| "\"math_fraud\",\"duplicate_submission\"], \"confidence\": 0.0-1.0}\n" | |
| "Output ONLY valid JSON: {\"audit_results\": [...]}" | |
| ) | |
| GENERATOR_SYSTEM = ( | |
| "You are an invoice generator. Create a realistic fraudulent invoice as JSON.\n" | |
| "Output ONLY valid JSON with keys: vendor, date, currency, total, line_items, invoice_id." | |
| ) | |
| _lock = threading.Lock() | |
| # Maps name -> (model, tokenizer, device) | None | |
| _cache: dict = {} | |
| _load_errors: dict = {} | |
| def _load(hub_id: str): | |
| """Load base model + LoRA adapter. Returns (model, tokenizer, device) or None.""" | |
| try: | |
| import torch | |
| from peft import PeftModel | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| ) | |
| has_gpu = torch.cuda.is_available() | |
| device = "cuda" if has_gpu else "cpu" | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) | |
| if has_gpu: | |
| bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) | |
| base = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| quantization_config=bnb, | |
| device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
| else: | |
| base = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| device_map="cpu", | |
| torch_dtype="auto", | |
| trust_remote_code=True, | |
| ) | |
| model = PeftModel.from_pretrained(base, hub_id) | |
| model.eval() | |
| return model, tokenizer, device | |
| except Exception as exc: | |
| print(f"[agents] Could not load {hub_id}: {exc}") | |
| return None | |
| def _get(name: str, hub_id: str): | |
| with _lock: | |
| if name not in _cache: | |
| print(f"[agents] Loading {name}…") | |
| result = _load(hub_id) | |
| _cache[name] = result | |
| if result is None: | |
| _load_errors[name] = f"load failed for {hub_id}" | |
| return _cache[name] | |
| def _generate(model, tokenizer, device: str, system: str, user: str, max_new_tokens: int = 400) -> str: | |
| import torch | |
| messages = [ | |
| {"role": "system", "content": system}, | |
| {"role": "user", "content": user}, | |
| ] | |
| inputs = tokenizer.apply_chat_template( | |
| messages, tokenize=True, add_generation_prompt=True, return_tensors="pt" | |
| ) | |
| if device == "cuda": | |
| inputs = inputs.to("cuda") | |
| with torch.no_grad(): | |
| out = model.generate( | |
| inputs, | |
| max_new_tokens=max_new_tokens, | |
| temperature=0.3, | |
| do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| ) | |
| return tokenizer.decode(out[0][inputs.shape[1]:], skip_special_tokens=True) | |
| def _parse(text: str): | |
| text = text.strip() | |
| if text.startswith("```"): | |
| text = re.sub(r"^```[a-z]*\n?", "", text) | |
| text = re.sub(r"```$", "", text).strip() | |
| try: | |
| return json.loads(text) | |
| except json.JSONDecodeError: | |
| return {} | |
| # --------------------------------------------------------------------------- | |
| # Public inference functions | |
| # --------------------------------------------------------------------------- | |
| def run_extractor(raw_text: str, ref_data: str = "") -> tuple[dict, bool]: | |
| """ | |
| Run trained Extractor LoRA on raw invoice text. | |
| Returns (extracted_dict, used_model). | |
| used_model=False means model unavailable, caller should use rule-based fallback. | |
| """ | |
| result = _get("extractor", EXTRACTOR_HUB) | |
| if result is None: | |
| return {}, False | |
| model, tokenizer, device = result | |
| user = (f"REF:\n{ref_data[:200]}\n\nINVOICE:\n{raw_text[:600]}" | |
| if ref_data else raw_text[:600]) | |
| text = _generate(model, tokenizer, device, EXTRACTOR_SYSTEM, user) | |
| parsed = _parse(text) | |
| if isinstance(parsed, list): | |
| parsed = parsed[0] if parsed else {} | |
| return (parsed if isinstance(parsed, dict) else {}), True | |
| def run_auditor(raw_text: str, ref_data: str, n_invoices: int) -> tuple[list, bool]: | |
| """ | |
| Run trained Auditor LoRA on invoice batch text. | |
| Returns (audit_results_list, used_model). | |
| """ | |
| result = _get("auditor", AUDITOR_HUB) | |
| if result is None: | |
| return [], False | |
| model, tokenizer, device = result | |
| user = ( | |
| f"INVOICE BATCH:\n{raw_text[:800]}\n\n" | |
| f"REFERENCE DATA:\n{ref_data[:400]}\n\n" | |
| 'Audit all invoices. Output: {"audit_results": [...]}' | |
| ) | |
| text = _generate(model, tokenizer, device, AUDITOR_SYSTEM, user) | |
| parsed = _parse(text) | |
| if isinstance(parsed, dict): | |
| results = parsed.get("audit_results", []) | |
| elif isinstance(parsed, list): | |
| results = parsed | |
| else: | |
| results = [] | |
| return results, True | |
| def run_generator(fraud_type: str, blind_spots: list | None = None) -> tuple[dict, bool]: | |
| """ | |
| Run trained Generator LoRA to create a fraudulent invoice. | |
| Returns (invoice_dict, used_model). | |
| """ | |
| result = _get("generator", GENERATOR_HUB) | |
| if result is None: | |
| return {}, False | |
| model, tokenizer, device = result | |
| ctx = f"Regulator blind spots: {blind_spots}" if blind_spots else "" | |
| user = ( | |
| f"Generate a realistic invoice with {fraud_type} fraud. {ctx}\n" | |
| "Output JSON only: {{vendor, date, currency, total, line_items, invoice_id}}" | |
| ) | |
| text = _generate(model, tokenizer, device, GENERATOR_SYSTEM, user) | |
| parsed = _parse(text) | |
| return (parsed if isinstance(parsed, dict) else {}), True | |
| def models_status() -> dict[str, str]: | |
| """Return load status for all 3 agents.""" | |
| names = ["extractor", "auditor", "generator"] | |
| status = {} | |
| for n in names: | |
| if n in _cache: | |
| status[n] = "loaded ✅" if _cache[n] is not None else f"failed ❌ ({_load_errors.get(n,'')})" | |
| else: | |
| status[n] = "not loaded yet" | |
| return status | |