Spaces:
Runtime error
Runtime error
"""
main.py — Support Triage Agent (Groq + Hugging Face)
=====================================================

Command-line entry point for a multi-domain (HackerRank / Claude / Visa)
support triage agent.  For every ticket it:

1. classifies the ticket (domain, request type, product area, confidence),
2. retrieves supporting documents from a BM25 index built over the corpus,
3. runs a safety gate that decides between answering and escalating,
4. generates either a reply or an escalation message.

The pipeline pieces live in the sibling modules ``corpus_builder``,
``retriever``, ``safety`` and ``agent``.

Usage:
    python main.py                                   # batch mode (reads support_issues.csv)
    python main.py --input path/to/issues.csv        # custom input
    python main.py --interactive                     # chat with the agent in terminal
    python main.py --interactive --input issues.csv  # process CSV then go interactive

    --output PATH and --log PATH override where the two files below are written.

Outputs:
    output.csv — predictions for submission
    log.txt    — full session transcript for submission
"""
| import os | |
| import sys | |
| import csv | |
| import argparse | |
| import datetime | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| from corpus_builder import build_corpus | |
| from retriever import MultiDomainRetriever | |
| from safety import should_escalate | |
| from agent import classify_ticket, generate_response, generate_escalation_message | |
# -- Paths --------------------------------------------------------------------
# Data and generated artefacts live in the parent of the directory that holds
# this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(_SCRIPT_DIR)

INPUT_CSV = os.path.join(BASE_DIR, "data", "support_issues", "support_issues.csv")
OUTPUT_CSV = os.path.join(BASE_DIR, "output.csv")
LOG_FILE = os.path.join(BASE_DIR, "log.txt")

# Column order of output.csv (passed to csv.DictWriter as fieldnames).
OUTPUT_FIELDS = [
    "ticket_id",
    "domain",
    "request_type",
    "product_area",
    "action",
    "escalation_reason",
    "response",
]

# Printed and written to the transcript when the Logger starts.
BANNER = """
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β SUPPORT TRIAGE AGENT β Groq + LLaMA 3 β
β HackerRank | Claude | Visa multi-domain support β
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
"""
# -- Logger -------------------------------------------------------------------
class Logger:
    """Tee every message to stdout and to a UTF-8 session transcript.

    The transcript is opened (and truncated) on construction, begins with
    ``BANNER`` and a start timestamp, and is flushed after every message so
    its contents survive a crash mid-run.  ``close()`` is idempotent, and
    instances can be used as context managers.
    """

    def __init__(self, path: str):
        """Open *path* for writing, creating its parent directory if needed.

        Args:
            path: Destination of the session transcript (overwritten).
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        self.f = open(path, "w", encoding="utf-8")
        try:
            self.log(BANNER)
            self.log(f"Session started: {datetime.datetime.now().isoformat()}\n")
        except BaseException:
            # Don't leak the file handle if the very first write fails.
            self.f.close()
            raise

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def log(self, msg: str) -> None:
        """Print *msg* and append it, plus a newline, to the transcript."""
        try:
            print(msg)
        except UnicodeEncodeError:
            # Consoles with a narrow encoding (e.g. cp1252) cannot display some
            # non-ASCII characters used in the banner/separators.  Degrade only
            # the console copy; the transcript below always gets the full text.
            encoding = getattr(sys.stdout, "encoding", None) or "ascii"
            print(msg.encode(encoding, errors="replace").decode(encoding))
        self.f.write(msg + "\n")
        self.f.flush()

    def close(self) -> None:
        """Write the end timestamp and close the transcript (safe to call twice)."""
        if self.f.closed:
            return
        self.log(f"\nSession ended: {datetime.datetime.now().isoformat()}")
        self.f.close()
# -- Core pipeline ------------------------------------------------------------
# Domains that have a dedicated retrieval index.
_KNOWN_DOMAINS = ("hackerrank", "claude", "visa")


def _retrieve_docs(domain, ticket_text: str, retriever, top_k: int, min_best_score: float) -> list:
    """Collect supporting documents for *ticket_text* (helper for process_ticket).

    For a known domain (matched case-insensitively), the domain index is
    searched first.  If its best score is below *min_best_score*, the results
    are topped up with a global search (``top_k_per_domain=1``), deduplicated
    by ``url``.  Any other domain goes straight to a global search with
    ``top_k_per_domain=2``.  Always returns a fresh list, so the retriever's
    own result list is never mutated.
    """
    key = str(domain or "").strip().lower()
    if key not in _KNOWN_DOMAINS:
        return list(retriever.retrieve_all(ticket_text, top_k_per_domain=2))

    docs = list(retriever.retrieve_for_domain(key, ticket_text, top_k=top_k))
    best = max((d.get("score", 0) for d in docs), default=0)
    if best < min_best_score:
        seen = {d["url"] for d in docs}
        for extra in retriever.retrieve_all(ticket_text, top_k_per_domain=1):
            if extra["url"] not in seen:
                docs.append(extra)
                # Track additions too, so duplicates across domains are dropped.
                seen.add(extra["url"])
    return docs


def process_ticket(
    ticket_id: str,
    ticket_text: str,
    retriever: MultiDomainRetriever,
    logger: Logger,
    *,
    top_k: int = 4,
    min_best_score: float = 1.5,
) -> dict:
    """Run the full triage pipeline for one ticket and return its output row.

    Steps: classify -> retrieve supporting docs -> safety gate -> generate a
    reply or an escalation message.  Every step is written to *logger*.

    Args:
        ticket_id: Identifier copied verbatim into the output row.
        ticket_text: The customer's issue text.
        retriever: Index used to look up supporting documents.
        logger: Console/transcript tee for the session.
        top_k: Number of documents to fetch from the domain-specific index.
        min_best_score: If the best domain-specific score is below this,
            results are supplemented with a global search.

    Returns:
        A dict with keys ticket_id, domain, request_type, product_area,
        action ("respond" or "escalate"), escalation_reason (empty when
        responding) and response.
    """
    logger.log(f"\n{'β'*60}")
    logger.log(f"TICKET : #{ticket_id}")
    logger.log(f"ISSUE : {ticket_text[:220]}{'...' if len(ticket_text) > 220 else ''}")

    # Step 1 - classify; missing keys fall back to these defaults.
    clf = classify_ticket(ticket_text)
    domain = clf.get("domain", "unknown")
    request_type = clf.get("request_type", "other")
    product_area = clf.get("product_area", "general")
    confidence = clf.get("confidence", "low")
    logger.log(f"CLASSIFY : domain={domain} | type={request_type} | area={product_area} | conf={confidence}")

    # Step 2 - retrieve supporting documents.
    docs = _retrieve_docs(domain, ticket_text, retriever, top_k, min_best_score)
    best_score = max((d.get("score", 0) for d in docs), default=0)
    logger.log(f"RETRIEVAL : {len(docs)} docs | best score={best_score:.2f}")
    for d in docs[:2]:
        logger.log(f" β [{d.get('score',0):.2f}] {(d.get('title') or '')[:65]}")

    # Step 3 - safety gate decides between answering and escalating.
    escalate, esc_reason = should_escalate(ticket_text, product_area, docs)
    logger.log(f"SAFETY : {'ESCALATE β ' if escalate else 'RESPOND β'}"
               + (f" | {esc_reason}" if escalate else ""))

    # Step 4 - generate the customer-facing text.
    if escalate:
        action = "escalate"
        response = generate_escalation_message(ticket_text, esc_reason)
    else:
        action = "respond"
        esc_reason = ""  # answered rows carry no escalation reason
        response = generate_response(ticket_text, docs)
    logger.log(f"RESPONSE :\n{response}")

    return {
        "ticket_id": ticket_id,
        "domain": domain,
        "request_type": request_type,
        "product_area": product_area,
        "action": action,
        "escalation_reason": esc_reason,
        "response": response,
    }
# -- Batch mode ---------------------------------------------------------------
def _normalize_row(row: dict) -> dict:
    """Return a copy of a csv.DictReader row keyed by stripped, lower-cased headers.

    Cells missing from short rows (``None``) become ``""``.  Overflow cells,
    which DictReader stores under the ``None`` key, are dropped.
    """
    return {
        key.strip().lower(): value or ""
        for key, value in row.items()
        if isinstance(key, str)
    }


def run_batch(input_csv: str, retriever: MultiDomainRetriever, logger: Logger) -> list[dict]:
    """Triage every ticket in *input_csv* and return one output row per ticket.

    The ticket id comes from ``ticket_id`` or ``id``, falling back to the
    1-based row number.  The text comes from the first non-empty of ``issue``,
    ``text``, ``description`` or ``message``.  Headers are matched
    case-insensitively, ignoring surrounding whitespace, and a leading UTF-8
    BOM is tolerated.  Rows without text are skipped (and logged).

    If the pipeline raises for a ticket, the error is logged and the ticket is
    recorded as an escalation, so one failing API call does not discard the
    rest of the batch.

    Returns:
        The result rows, or ``[]`` if *input_csv* does not exist.
    """
    if not os.path.exists(input_csv):
        logger.log(f"[WARN] Input CSV not found: {input_csv}")
        logger.log(" Run with --interactive to test manually.")
        return []

    # utf-8-sig strips a BOM if present; with plain utf-8 the first header
    # would read "\ufeffticket_id" and every id lookup would silently miss.
    with open(input_csv, newline="", encoding="utf-8-sig") as f:
        rows = list(csv.DictReader(f))
    logger.log(f"\nProcessing {len(rows)} tickets from: {input_csv}")

    total = len(rows)
    results = []
    for i, raw_row in enumerate(rows, 1):
        row = _normalize_row(raw_row)
        tid = row.get("ticket_id") or row.get("id") or str(i)
        text = (row.get("issue") or row.get("text") or
                row.get("description") or row.get("message") or "").strip()
        if not text:
            logger.log(f"\n[{i}/{total}] skipped ticket #{tid}: no issue text")
            continue
        logger.log(f"\n[{i}/{total}]")
        try:
            result = process_ticket(tid, text, retriever, logger)
        except Exception as exc:  # batch boundary: record the failure, keep going
            logger.log(f"[ERROR] ticket #{tid} failed: {type(exc).__name__}: {exc}")
            result = {
                "ticket_id": tid,
                "domain": "unknown",
                "request_type": "other",
                "product_area": "general",
                "action": "escalate",
                "escalation_reason": f"automated triage failed: {type(exc).__name__}",
                "response": "",
            }
        results.append(result)
    return results
def write_output(results: list[dict], output_csv: str):
    """Write *results* to *output_csv* as UTF-8 CSV with an ``OUTPUT_FIELDS`` header.

    The parent directory is created first (as ``Logger`` does for the
    transcript).  Otherwise a missing folder would only fail here, after all
    the tickets had already been processed.

    Raises:
        ValueError: if a result has keys not in ``OUTPUT_FIELDS``
            (``csv.DictWriter``'s default ``extrasaction="raise"``).
    """
    os.makedirs(os.path.dirname(output_csv) or ".", exist_ok=True)
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=OUTPUT_FIELDS)
        writer.writeheader()
        writer.writerows(results)
    print(f"\nβ output.csv written β {output_csv} ({len(results)} rows)")
# -- Interactive mode ---------------------------------------------------------
def run_interactive(retriever: MultiDomainRetriever, logger: Logger):
    """Read issues from the terminal in a loop and triage each one.

    The loop ends on ``quit``, ``exit`` or ``q`` (any case), on an empty line,
    or on EOF/Ctrl-C at the prompt.  Tickets are numbered from 1.  If
    triaging a message raises, the error is logged and the session continues
    at the next prompt instead of terminating.
    """
    logger.log("\n" + "β"*60)
    logger.log("INTERACTIVE MODE β type your support issue, 'quit' to exit")
    logger.log("β"*60 + "\n")
    ticket_id = 1
    while True:
        try:
            text = input("You > ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if text.lower() in ("quit", "exit", "q", ""):
            break
        try:
            result = process_ticket(str(ticket_id), text, retriever, logger)
        except Exception as exc:  # interactive boundary: report and keep chatting
            logger.log(f"[ERROR] ticket #{ticket_id} failed: {type(exc).__name__}: {exc}")
        else:
            print(f"\nAgent > {result['response']}\n")
        ticket_id += 1
# -- Entry point --------------------------------------------------------------
def main():
    """Parse CLI arguments, build the retrieval index, then run batch and/or chat.

    Batch mode reads ``--input`` (default ``INPUT_CSV``).  With
    ``--interactive`` and no explicit ``--input``, the batch step is skipped
    and only the chat runs, as the module usage notes describe.  With both
    flags, the CSV is processed first.  The transcript logger is always
    closed, even if a step raises.
    """
    parser = argparse.ArgumentParser(description="Support Triage Agent β Groq + BM25")
    parser.add_argument(
        "--input",
        default=None,
        help=("Path to input support_issues.csv (defaults to "
              "data/support_issues/support_issues.csv; in --interactive mode "
              "it is only read when given explicitly)"),
    )
    parser.add_argument("--output", default=OUTPUT_CSV, help="Path for output.csv")
    parser.add_argument("--log", default=LOG_FILE, help="Path for log.txt")
    parser.add_argument("--interactive", action="store_true", help="Run interactive terminal mode")
    args = parser.parse_args()

    # Decide which CSV (if any) to batch-process.
    if args.input is not None:
        input_csv = args.input
    elif args.interactive:
        input_csv = None  # plain --interactive is a chat session only
    else:
        input_csv = INPUT_CSV

    logger = Logger(args.log)
    try:
        # 1. Load corpus
        logger.log("[1/3] Building corpus...")
        corpus = build_corpus()
        # 2. Build retriever
        logger.log("[2/3] Building BM25 index...")
        retriever = MultiDomainRetriever(corpus)
        logger.log(" Index ready.\n")
        # 3. Process
        logger.log("[3/3] Running triage agent...")
        if input_csv is not None:
            results = run_batch(input_csv, retriever, logger)
            if results:
                write_output(results, args.output)
        if args.interactive:
            run_interactive(retriever, logger)
        logger.log("\n" + "β"*60)
        logger.log("SESSION COMPLETE")
    finally:
        # Always record the end timestamp and release the transcript file.
        logger.close()


if __name__ == "__main__":
    main()