Spaces:
Sleeping
Sleeping
Upload backend/core/management/commands/test_legal_training.py with huggingface_hub
Browse files
backend/core/management/commands/test_legal_training.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any, Dict, List, Tuple
|
| 6 |
+
|
| 7 |
+
from django.core.management.base import BaseCommand
|
| 8 |
+
|
| 9 |
+
from hue_portal.chatbot.chatbot import get_chatbot
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Command(BaseCommand):
    """
    Smoke-test the legal intent classifier and RAG retrieval.

    Samples questions from the generated QA JSON files under
    backend/hue_portal/chatbot/training/generated_qa/, runs the intent
    classifier on each one, and — when the RAG pipeline is importable —
    calls it with use_llm=False to report the retrieved content_type and
    document count.

    Intended for operators to run occasionally after auto-training to
    confirm that most legal questions classify as `search_legal` and that
    RAG returns legal content for them.
    """

    help = "Run a small evaluation of legal intent & RAG using generated QA questions"

    def add_arguments(self, parser) -> None:
        # Cap how many questions are drawn from each per-document JSON file.
        parser.add_argument(
            "--max-per-doc",
            type=int,
            default=20,
            help="Maximum number of questions to sample per document JSON file.",
        )

    def handle(self, *args: Any, **options: Any) -> None:
        sample_limit: int = options["max_per_doc"]

        # generated_qa lives four levels above this command module.
        qa_dir = Path(__file__).resolve().parents[4] / "chatbot" / "training" / "generated_qa"
        if not qa_dir.exists():
            self.stdout.write(self.style.WARNING(f"No generated QA directory found at {qa_dir}"))
            return

        bot = get_chatbot()

        # Best-effort import: RAG inspection is simply skipped when the
        # pipeline module is unavailable.
        try:
            from hue_portal.core.rag import rag_pipeline  # type: ignore
        except Exception:
            rag_pipeline = None  # type: ignore

        evaluated = 0
        legal_hits = 0
        non_legal = 0

        self.stdout.write(self.style.MIGRATE_HEADING("Evaluating legal intent & RAG on generated QA..."))

        for qa_file in sorted(qa_dir.glob("*.json")):
            try:
                records = json.loads(qa_file.read_text(encoding="utf-8"))
            except Exception:
                self.stdout.write(self.style.WARNING(f"Skipping malformed QA file: {qa_file.name}"))
                continue

            # Each file is expected to hold a list of QA dicts; skip others.
            if not isinstance(records, list):
                continue

            self.stdout.write(self.style.HTTP_INFO(f"File: {qa_file.name}"))

            for record in records[:sample_limit]:
                if not isinstance(record, dict):
                    continue
                question = str(record.get("question") or "").strip()
                if not question:
                    continue

                intent, confidence = bot.classify_intent(question)
                evaluated += 1
                if intent == "search_legal":
                    legal_hits += 1
                else:
                    non_legal += 1

                # (content_type, count) summary of the RAG result, if any.
                rag_info: Tuple[str, int] = ("n/a", 0)
                if rag_pipeline is not None:
                    try:
                        outcome: Dict[str, Any] = rag_pipeline(
                            question,
                            intent,
                            top_k=3,
                            min_confidence=confidence,
                            context=None,
                            use_llm=False,
                        )
                    except Exception:
                        rag_info = ("error", 0)
                    else:
                        rag_info = (
                            str(outcome.get("content_type") or "n/a"),
                            int(outcome.get("count") or 0),
                        )

                self.stdout.write(
                    f"- Q: {question[:80]}... | intent={intent} ({confidence:.2f}) "
                    f"| RAG type={rag_info[0]} count={rag_info[1]}"
                )

        self.stdout.write("")
        if evaluated == 0:
            self.stdout.write(self.style.WARNING("No questions evaluated."))
            return

        pct_legal = (legal_hits / evaluated) * 100.0
        self.stdout.write(
            self.style.SUCCESS(
                f"Total questions: {evaluated} | search_legal: {legal_hits} ({pct_legal:.1f}%) "
                f"| other intents: {non_legal}"
            )
        )
|
| 122 |
+
|
| 123 |
+
|