# Uploaded by MalikShehram ("Upload 9 files", commit 8cadb90, verified)
"""
main.py — Support Triage Agent (Groq + Hugging Face)
=====================================================
Usage:
python main.py # batch mode (reads support_issues.csv)
python main.py --input path/to/issues.csv # custom input
python main.py --interactive # chat with the agent in terminal
python main.py --interactive --input issues.csv # process CSV then go interactive
Outputs:
output.csv — predictions for submission
log.txt — full session transcript for submission
"""
import os
import sys
import csv
import argparse
import datetime
sys.path.insert(0, os.path.dirname(__file__))
from corpus_builder import build_corpus
from retriever import MultiDomainRetriever
from safety import should_escalate
from agent import classify_ticket, generate_response, generate_escalation_message
# ── Paths ─────────────────────────────────────────────────────────────────────
# The project root is the parent of the directory that contains this script.
BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)

# Default batch input, predictions file and session transcript, all rooted
# at the project directory.
INPUT_CSV = os.path.join(BASE_DIR, "data", "support_issues", "support_issues.csv")
OUTPUT_CSV = os.path.join(BASE_DIR, "output.csv")
LOG_FILE = os.path.join(BASE_DIR, "log.txt")

# Column order for output.csv.
OUTPUT_FIELDS = [
    "ticket_id",
    "domain",
    "request_type",
    "product_area",
    "action",
    "escalation_reason",
    "response",
]
# Header shown at the start of each session. Built programmatically so the
# box-drawing borders always line up with the centred text, and written with
# the real Unicode characters (the previous literal was mis-encoded).
_BANNER_WIDTH = 58
BANNER = "\n".join([
    "",
    "╔" + "═" * _BANNER_WIDTH + "╗",
    "║" + "SUPPORT TRIAGE AGENT — Groq + LLaMA 3".center(_BANNER_WIDTH) + "║",
    "║" + "HackerRank | Claude | Visa multi-domain support".center(_BANNER_WIDTH) + "║",
    "╚" + "═" * _BANNER_WIDTH + "╝",
    "",
])
# ── Logger ────────────────────────────────────────────────────────────────────
class Logger:
    """Tee messages to stdout and to a UTF-8 session transcript file.

    Can be used as a context manager. ``close()`` is idempotent, so it is safe
    to call from both a ``finally`` block and ``__exit__``.
    """

    def __init__(self, path: str):
        """Open (truncating) the transcript at *path* and write the session header.

        Missing parent directories are created first.
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        self.f = open(path, "w", encoding="utf-8")
        try:
            self.log(BANNER)
            self.log(f"Session started: {datetime.datetime.now().isoformat()}\n")
        except BaseException:
            # Don't leak the handle if writing the header fails
            # (e.g. stdout cannot encode the banner characters).
            self.f.close()
            raise

    def log(self, msg: str) -> None:
        """Print *msg* and append it, plus a newline, to the transcript.

        The file is flushed after every write so the transcript stays current
        even if the process dies mid-session.
        """
        print(msg)
        self.f.write(msg + "\n")
        self.f.flush()

    def close(self) -> None:
        """Write the session footer and close the transcript; later calls are no-ops."""
        if self.f.closed:
            return
        self.log(f"\nSession ended: {datetime.datetime.now().isoformat()}")
        self.f.close()

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
# ── Core pipeline ─────────────────────────────────────────────────────────────
def _retrieve_docs(
    domain: str,
    ticket_text: str,
    retriever: MultiDomainRetriever,
    *,
    top_k: int = 4,
    min_score: float = 1.5,
) -> list[dict]:
    """Return knowledge-base docs relevant to *ticket_text*.

    For a recognised domain ("hackerrank", "claude", "visa") the domain's own
    index is queried for ``top_k`` docs. If the best score among them is below
    ``min_score``, one doc per domain from a global search is appended, skipping
    URLs already present. Any other domain goes straight to a global search
    with two docs per domain.
    """
    if domain not in ("hackerrank", "claude", "visa"):
        return retriever.retrieve_all(ticket_text, top_k_per_domain=2)

    # Copy so appending extras never mutates a list the retriever may hold on to.
    docs = list(retriever.retrieve_for_domain(domain, ticket_text, top_k=top_k))
    best = max((d.get("score", 0) for d in docs), default=0)
    if best < min_score:
        # Domain results look weak: supplement with a cross-domain search.
        seen = {d["url"] for d in docs}
        for d in retriever.retrieve_all(ticket_text, top_k_per_domain=1):
            if d["url"] not in seen:
                seen.add(d["url"])  # also dedupe within the extras themselves
                docs.append(d)
    return docs


def process_ticket(
    ticket_id: str,
    ticket_text: str,
    retriever: MultiDomainRetriever,
    logger: Logger,
) -> dict:
    """Run the full triage pipeline for one ticket and return its output row.

    Steps: classify -> retrieve supporting docs -> safety gate -> generate
    either an answer or an escalation message. Each step is written to *logger*.

    Returns:
        A dict with keys ticket_id, domain, request_type, product_area,
        action ("respond" or "escalate"), escalation_reason (empty string when
        responding) and response.
    """
    logger.log(f"\n{'─'*60}")
    logger.log(f"TICKET : #{ticket_id}")
    logger.log(f"ISSUE : {ticket_text[:220]}{'...' if len(ticket_text) > 220 else ''}")

    # Step 1 — Classify
    clf = classify_ticket(ticket_text)
    domain = clf.get("domain", "unknown")
    request_type = clf.get("request_type", "other")
    product_area = clf.get("product_area", "general")
    confidence = clf.get("confidence", "low")
    logger.log(f"CLASSIFY : domain={domain} | type={request_type} | area={product_area} | conf={confidence}")

    # Step 2 — Retrieve
    docs = _retrieve_docs(domain, ticket_text, retriever)
    best_score = max((d.get("score", 0) for d in docs), default=0)
    logger.log(f"RETRIEVAL : {len(docs)} docs | best score={best_score:.2f}")
    for d in docs[:2]:
        # A doc without a title must not crash the pipeline from a log line.
        logger.log(f" → [{d.get('score', 0):.2f}] {(d.get('title') or '')[:65]}")

    # Step 3 — Safety gate
    escalate, esc_reason = should_escalate(ticket_text, product_area, docs)
    logger.log(f"SAFETY : {'ESCALATE ⚠' if escalate else 'RESPOND ✓'}"
               + (f" | {esc_reason}" if escalate else ""))

    # Step 4 — Generate
    if escalate:
        action = "escalate"
        response = generate_escalation_message(ticket_text, esc_reason)
    else:
        action = "respond"
        esc_reason = ""
        response = generate_response(ticket_text, docs)
    logger.log(f"RESPONSE :\n{response}")

    return {
        "ticket_id": ticket_id,
        "domain": domain,
        "request_type": request_type,
        "product_area": product_area,
        "action": action,
        "escalation_reason": esc_reason,
        "response": response,
    }
# ── Batch mode ────────────────────────────────────────────────────────────────
def run_batch(input_csv: str, retriever: MultiDomainRetriever, logger: Logger) -> list[dict]:
    """Triage every ticket in *input_csv* and return one result dict per ticket.

    The ticket id is read from a ``ticket_id`` or ``id`` column, falling back
    to the 1-based row number. The issue text comes from the first non-empty
    column among ``issue``, ``text``, ``description`` and ``message``. Rows
    without issue text are skipped and noted in the log. A missing file logs
    a warning and yields an empty list.
    """
    if not os.path.exists(input_csv):
        logger.log(f"[WARN] Input CSV not found: {input_csv}")
        logger.log(" Run with --interactive to test manually.")
        return []

    # utf-8-sig strips a leading BOM (common in Excel exports). With plain
    # utf-8 the BOM would stick to the first header and "ticket_id" would
    # never match.
    with open(input_csv, newline="", encoding="utf-8-sig") as f:
        rows = list(csv.DictReader(f))

    total = len(rows)
    logger.log(f"\nProcessing {total} tickets from: {input_csv}")
    results = []
    for i, row in enumerate(rows, 1):
        tid = (row.get("ticket_id") or row.get("id") or "").strip() or str(i)
        text = (row.get("issue") or row.get("text") or
                row.get("description") or row.get("message") or "").strip()
        if not text:
            logger.log(f"\n[{i}/{total}] skipped: no issue text")
            continue
        logger.log(f"\n[{i}/{total}]")
        results.append(process_ticket(tid, text, retriever, logger))
    return results
def write_output(results: list[dict], output_csv: str, *, fieldnames=None):
    """Write triage *results* to *output_csv* as a CSV with a header row.

    Args:
        results: One dict per ticket.
        output_csv: Destination path. Missing parent directories are created.
        fieldnames: Column order; defaults to OUTPUT_FIELDS. With
            ``csv.DictWriter`` defaults, keys outside this list raise
            ValueError and missing keys become empty cells.
    """
    columns = OUTPUT_FIELDS if fieldnames is None else fieldnames
    os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=columns)
        writer.writeheader()
        writer.writerows(results)
    print(f"\n✅ output.csv written → {output_csv} ({len(results)} rows)")
# ── Interactive mode ──────────────────────────────────────────────────────────
def run_interactive(retriever: MultiDomainRetriever, logger: Logger):
    """Read issues from the terminal and triage each one until the user stops.

    The loop ends on 'quit', 'exit' or 'q' (case-insensitive), on an empty
    line, or on EOF / Ctrl-C. Tickets are numbered 1, 2, 3, ... in entry order.
    """
    logger.log("\n" + "═"*60)
    logger.log("INTERACTIVE MODE — type your support issue, 'quit' to exit")
    logger.log("═"*60 + "\n")
    ticket_id = 1
    while True:
        try:
            text = input("You > ").strip()
        except (EOFError, KeyboardInterrupt):
            # Finish the half-written prompt line so later output starts cleanly.
            print()
            break
        if text.lower() in ("quit", "exit", "q", ""):
            break
        result = process_ticket(str(ticket_id), text, retriever, logger)
        print(f"\nAgent > {result['response']}\n")
        ticket_id += 1
# ── Entry point ───────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: build the corpus and index, then triage tickets.

    Batch mode runs on ``--input`` (or the default CSV) unless ``--interactive``
    is given without ``--input``. In that case only the interactive session
    runs, so an existing output.csv is left untouched. The transcript is
    always closed, even if a step raises.
    """
    parser = argparse.ArgumentParser(description="Support Triage Agent — Groq + BM25")
    parser.add_argument(
        "--input",
        default=None,
        help="Path to input support_issues.csv "
             "(default: data/support_issues/support_issues.csv under the project root)",
    )
    parser.add_argument("--output", default=OUTPUT_CSV, help="Path for output.csv")
    parser.add_argument("--log", default=LOG_FILE, help="Path for log.txt")
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Run interactive terminal mode (batch is skipped unless --input is given)",
    )
    args = parser.parse_args()

    # `--interactive` alone means "chat only" (see the module usage notes).
    # Batch-process only when no mode flag was given or a CSV was named explicitly.
    do_batch = args.input is not None or not args.interactive
    input_csv = args.input if args.input is not None else INPUT_CSV

    logger = Logger(args.log)
    try:
        # 1. Load corpus
        logger.log("[1/3] Building corpus...")
        corpus = build_corpus()

        # 2. Build retriever
        logger.log("[2/3] Building BM25 index...")
        retriever = MultiDomainRetriever(corpus)
        logger.log(" Index ready.\n")

        # 3. Process
        logger.log("[3/3] Running triage agent...")
        if do_batch:
            results = run_batch(input_csv, retriever, logger)
            if results:
                write_output(results, args.output)
        if args.interactive:
            run_interactive(retriever, logger)

        logger.log("\n" + "═"*60)
        logger.log("SESSION COMPLETE")
    finally:
        logger.close()
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()