Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Command-line interface for NLProxy compression operations. | |
| Enterprise-ready, async-compatible, and fully tested. | |
| Author: IntelliDeep Labs Team | |
| License: BSL 1.1 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| import textwrap | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| try: | |
| from rich.console import Console | |
| from rich.table import Table | |
| _RICH_AVAILABLE = True | |
| except ImportError: | |
| _RICH_AVAILABLE = False | |
| Console = None | |
| from nlproxy.core.restriction import Restriction | |
| from nlproxy.service.compression import CompressionService | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================= | |
| # CLI UTILITIES | |
| # ============================================================================= | |
def setup_logging(level: str = "INFO") -> None:
    """Configure the root logger to emit timestamped records on stderr.

    Args:
        level: Case-insensitive name of a ``logging`` level such as
            ``"debug"`` or ``"INFO"``. A name that is not an attribute of
            the ``logging`` module falls back to ``logging.INFO``.
    """
    resolved_level = getattr(logging, level.upper(), logging.INFO)
    config = {
        "level": resolved_level,
        "format": "%(asctime)s [%(levelname)-8s] %(message)s",
        "datefmt": "%H:%M:%S",
        # Diagnostics go to stderr so stdout stays reserved for results.
        "stream": sys.stderr,
    }
    logging.basicConfig(**config)
def _write_payload(dest: Any, data: Any, fmt: str) -> None:
    """Serialize *data* to the open text stream *dest* in the given format."""
    if fmt == "json":
        json.dump(data, dest, indent=2, default=str)
        dest.write("\n")
    elif fmt == "jsonl":
        # One JSON document per line; a non-list value becomes a single line.
        records = data if isinstance(data, list) else [data]
        dest.writelines(json.dumps(record, default=str) + "\n" for record in records)
    elif isinstance(data, dict):
        dest.writelines(f"{key}: {value}\n" for key, value in data.items())
    elif isinstance(data, list):
        dest.writelines(f"{item}\n" for item in data)
    else:
        text = str(data)
        # Guarantee exactly one trailing newline without doubling an existing one.
        dest.write(text if text.endswith("\n") else text + "\n")


def print_output(data: Any, format: str = "text", output_file: Optional[str] = None, quiet: bool = False) -> None:
    """Write *data* to stdout or to a file as text, JSON, or JSON Lines.

    Args:
        data: Value to emit. In text mode a dict is written as ``key: value``
            lines, a list as one item per line, anything else via ``str()``.
        format: ``"json"``, ``"jsonl"``; any other value selects text mode.
        output_file: Destination path. ``None``, ``""`` or ``"-"`` mean stdout.
        quiet: When true, text output to *stdout* is suppressed. Output to an
            explicitly requested file is always written, as are the
            machine-readable formats.

    Raises:
        OSError: If *output_file* cannot be opened for writing.
    """
    to_file = bool(output_file) and output_file != "-"

    # Quiet mode only silences text on the console; it must never discard a
    # result the caller asked to have saved to a file.
    if quiet and format == "text" and not to_file:
        return

    if to_file:
        with open(output_file, "w", encoding="utf-8") as handle:
            _write_payload(handle, data, format)
    else:
        # Look up sys.stdout at call time so redirection is honoured.
        _write_payload(sys.stdout, data, format)
def error_exit(message: str, code: int = 1, quiet: bool = False) -> None:
    """Report an error on stderr and terminate the process.

    Args:
        message: Human-readable description of the failure.
        code: Process exit status passed to ``sys.exit``.
        quiet: When true, exit without printing anything.

    Raises:
        SystemExit: Always; this function does not return.
    """
    if not quiet:
        if _RICH_AVAILABLE and Console:
            # The message often embeds exception text or user input. Escape it
            # so square brackets are printed literally instead of being parsed
            # as rich markup tags (which can drop text or raise MarkupError).
            from rich.markup import escape

            Console(stderr=True).print(f"[bold red]Error:[/bold red] {escape(message)}")
        else:
            print(f"Error: {message}", file=sys.stderr)
    sys.exit(code)
def success_exit(message: Optional[str] = None, quiet: bool = False) -> None:
    """Optionally print a success message on stdout, then exit with status 0.

    Args:
        message: Text to show after a check mark; nothing is printed when it
            is ``None`` or empty.
        quiet: When true, exit without printing anything.

    Raises:
        SystemExit: Always, with exit status 0.
    """
    if message and not quiet:
        if _RICH_AVAILABLE and Console:
            # Escape the caller-supplied text so brackets in it are shown
            # literally rather than interpreted as rich markup.
            from rich.markup import escape

            Console().print(f"[bold green]✓[/bold green] {escape(message)}")
        else:
            print(f"✓ {message}")
    sys.exit(0)
| # ============================================================================= | |
| # COMMAND: compress | |
| # ============================================================================= | |
def _read_prompts(args: argparse.Namespace) -> List[str]:
    """Collect the prompt(s) to compress from a file, stdin, or ``--input-text``."""
    if args.batch:
        # Batch mode: one prompt per non-blank line; stdin when no file is given.
        source = args.input_file or "-"
        if source == "-":
            return [line.strip() for line in sys.stdin if line.strip()]
        with open(source, "r", encoding="utf-8") as handle:
            return [line.strip() for line in handle if line.strip()]

    if args.input_file:
        if args.input_file == "-":
            text = sys.stdin.read().strip()
        else:
            with open(args.input_file, "r", encoding="utf-8") as handle:
                text = handle.read().strip()
    elif args.input_text:
        text = args.input_text
    else:
        text = sys.stdin.read().strip()
        if not text:
            error_exit("No input provided. Use --input, --input-text, or pipe to stdin.")

    # An empty file or empty stdin yields no prompt instead of an empty one,
    # so the caller reports it rather than sending "" to the service.
    return [text] if text else []


def _parse_restrictions(raw: Optional[str]) -> "Optional[List[Restriction]]":
    """Decode the ``--restrictions`` JSON list into ``Restriction`` objects."""
    if not raw:
        return None
    try:
        entries = json.loads(raw)
        return [Restriction(**entry) for entry in entries]
    except (json.JSONDecodeError, TypeError) as exc:
        error_exit(f"Invalid restrictions JSON: {exc}")
    return None


def _print_summary(results: List[Dict[str, Any]], elapsed: float) -> None:
    """Write run statistics to stderr so they never mix with results on stdout."""
    if _RICH_AVAILABLE and Console:
        count = len(results)
        total_orig = sum(r.get("original_tokens", 0) for r in results)
        total_comp = sum(r.get("compressed_tokens", 0) for r in results)
        avg_ratio = sum(r.get("compression_ratio", 0) for r in results) / count if results else 0

        table = Table(title="NLProxy Compression Summary", show_lines=True)
        table.add_column("Metric", style="cyan")
        table.add_column("Value", style="green")
        table.add_row("Prompts processed", str(count))
        table.add_row("Total tokens (original)", f"{total_orig:,}")
        table.add_row("Total tokens (compressed)", f"{total_comp:,}")
        table.add_row("Tokens saved", f"{total_orig - total_comp:,} ({avg_ratio:.1%} avg)")
        table.add_row("Processing time", f"{elapsed:.2f}s")
        table.add_row("Avg latency/prompt", f"{(elapsed / count * 1000):.1f}ms" if results else "N/A")
        # stderr, matching the plain-text fallback below: a table on stdout
        # would corrupt piped json/jsonl/text output.
        Console(stderr=True).print(table)
    else:
        print(f"\nProcessed {len(results)} prompt(s) in {elapsed:.2f}s", file=sys.stderr)


def cmd_compress(args: argparse.Namespace) -> None:
    """Run the ``compress`` command end to end and exit the process.

    Steps: configure logging, gather the input prompt(s), build a
    ``CompressionService``, validate ``--restrictions``, run
    ``compress_batch_async`` under a single overall timeout, write the result
    via ``print_output`` and, unless quiet, print a summary to stderr.

    Args:
        args: Parsed namespace produced by ``create_parser()``.

    Raises:
        SystemExit: Always. Status 0 on success (via ``success_exit``),
            status 1 on any handled failure (via ``error_exit``).
        OSError: If an input or output file cannot be opened; not handled here.
    """
    setup_logging(args.log_level)

    # 1. Resolve input
    input_texts = _read_prompts(args)
    if not input_texts:
        error_exit("No prompts to process.")

    if not args.quiet:
        logger.info("Initializing CompressionService (mode=%s)...", args.mode)

    # 2. Initialize service
    try:
        service = CompressionService(
            use_cache=not args.no_cache,
            redis_url=args.redis_url or os.getenv("NLPROXY_REDIS_URL"),
            privacy_mode=args.privacy_mode,
            models_dir=Path(args.models_dir) if args.models_dir else None,
        )
    except Exception as e:
        error_exit(f"Failed to initialize CompressionService: {e}")

    # 3. Parse manual restrictions
    # NOTE(review): the restrictions are validated here but are not forwarded
    # to compress_batch_async below, so they currently have no effect on the
    # result. Confirm the service's parameter for them before wiring it in.
    manual_restrictions = _parse_restrictions(args.restrictions)

    # 4. Execute compression
    results: List[Dict[str, Any]] = []
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    started = time.perf_counter()
    if not args.quiet and len(input_texts) > 1:
        logger.info("Processing %d prompts...", len(input_texts))
    try:
        # The timeout bounds the whole batch call, not each individual prompt.
        results = asyncio.run(
            asyncio.wait_for(
                service.compress_batch_async(
                    texts=input_texts,
                    aggressiveness=args.aggressiveness,
                    mode=args.mode,
                    nli_active=args.nli_active,
                    language=args.language,
                    privacy_mode=args.privacy_mode,
                ),
                timeout=args.timeout,
            )
        )
    except asyncio.TimeoutError:
        error_exit(f"Compression timed out after {args.timeout}s")
    except Exception as e:
        error_exit(f"Compression failed: {e}")
    elapsed = time.perf_counter() - started

    # 5. Format output
    # Outside batch mode, text and json emit only the first result; jsonl and
    # batch mode emit the full list.
    wants_single = not args.batch and args.format != "jsonl"
    if wants_single and not results:
        error_exit("Compression returned no results.")

    if args.format == "text":
        if args.batch:
            output = [r["compressed_text"] for r in results]
        else:
            output = results[0]["compressed_text"]
    else:
        output = results[0] if wants_single else results

    if isinstance(output, dict):
        # Run metadata is attached only to the single-object json form.
        output["_processing_time_ms"] = round(elapsed * 1000, 2)
        output["_prompts_processed"] = len(results)

    print_output(output, format=args.format, output_file=args.output_file, quiet=args.quiet)

    # 6. Summary
    if not args.quiet:
        _print_summary(results, elapsed)

    success_exit(quiet=args.quiet)
| # ============================================================================= | |
| # PARSER (NO SUBPARSERS - FLAT STRUCTURE) | |
| # ============================================================================= | |
def create_parser() -> argparse.ArgumentParser:
    """Build the flat (no sub-commands) argument parser for ``nlproxy compress``.

    Returns:
        A parser whose namespace exposes ``input_text``, ``input_file``,
        ``output_file``, ``format``, ``batch``, ``aggressiveness``, ``mode``,
        ``nli_active``, ``language``, ``privacy_mode``, ``no_cache``,
        ``restrictions``, ``timeout``, ``log_level``, ``quiet``, ``redis_url``
        and ``models_dir``.
    """
    parser = argparse.ArgumentParser(
        prog="nlproxy compress",
        description="Compress prompt(s) using NLProxy's semantic pipeline.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent("""
            Examples:
              nlproxy compress --input prueba.txt --output result.txt
              nlproxy compress --input-text "Hola mundo" --mode code
              cat prompts.txt | nlproxy compress --batch --format jsonl
        """),
    )

    # Input/Output (type=str avoids TextIOWrapper errors)
    parser.add_argument("-i", "--input-text", type=str, help="Direct prompt text")
    parser.add_argument("-f", "--input-file", "--input", type=str, dest="input_file", help="Path to input file (use '-' for stdin)")
    parser.add_argument("-o", "--output-file", "--output", type=str, dest="output_file", help="Path to output file (use '-' for stdout)")

    # Core flags
    parser.add_argument("--format", type=str, choices=["text", "json", "jsonl"], default="text", help="Output format")
    parser.add_argument("-b", "--batch", action="store_true", help="Process multiple prompts (one per line)")
    parser.add_argument("-a", "--aggressiveness", type=float, default=0.2, metavar="FLOAT", help="Compression aggressiveness: 0.0 to 1.0")
    # The default mode may be overridden through the NLPROXY_DEFAULT_MODE environment variable.
    parser.add_argument("-m", "--mode", type=str, choices=["general", "code", "finance", "legal"], default=os.getenv("NLPROXY_DEFAULT_MODE", "general"), help="Domain mode")
    parser.add_argument("--nli-active", action="store_true", help="Enable NLI-based semantic refinement")
    parser.add_argument("-l", "--language", type=str, help="Language code for segmentation")
    parser.add_argument("-p", "--privacy-mode", action="store_true", help="Enable strict PII handling")
    parser.add_argument("--no-cache", action="store_true", help="Disable result caching")
    parser.add_argument("--restrictions", type=str, help='JSON list of manual restrictions')
    # The command applies this limit once to the entire compression call, so
    # the help text describes an overall (not per-prompt) timeout.
    parser.add_argument("-t", "--timeout", type=int, default=60, metavar="SECONDS", help="Overall timeout in seconds for the whole run (all prompts)")

    # Global pass-through
    # -v stores "DEBUG" into log_level; without it log_level stays "INFO".
    parser.add_argument("-v", "--verbose", action="store_const", const="DEBUG", dest="log_level", default="INFO", help="Enable verbose logging")
    parser.add_argument("-q", "--quiet", action="store_true", help="Suppress non-essential output")
    parser.add_argument("--redis-url", type=str, help="Redis URL for caching")
    parser.add_argument("--models-dir", type=str, help="Directory containing pre-downloaded models")
    return parser
def main(argv: Optional[List[str]] = None) -> int:
    """Parse command-line arguments and run the compress command.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use ``sys.argv[1:]``.

    Returns:
        0 when the command returns normally or the output pipe was closed by
        the reader, 1 as a fallback. In practice ``cmd_compress`` and
        ``error_exit`` terminate through ``SystemExit``, which is deliberately
        left to propagate to the caller.

    Raises:
        SystemExit: From argparse on invalid arguments, and from the
            success/error exit helpers (status 130 on Ctrl-C).
    """
    parser = create_parser()
    args = parser.parse_args(argv)
    quiet = getattr(args, "quiet", False)
    try:
        cmd_compress(args)
        return 0
    except KeyboardInterrupt:
        error_exit("Interrupted by user", code=130, quiet=quiet)
    except BrokenPipeError:
        # The downstream reader (e.g. `head`) closed the pipe. Point stdout at
        # devnull so the interpreter's final flush cannot fail again; closing
        # stderr only hid that failure and the process still exited non-zero.
        try:
            devnull_fd = os.open(os.devnull, os.O_WRONLY)
            os.dup2(devnull_fd, sys.stdout.fileno())
        except (OSError, ValueError):
            # Best effort: stdout may be a replaced object without a real
            # file descriptor, in which case there is nothing to redirect.
            pass
        return 0
    except Exception as e:
        error_exit(str(e), quiet=quiet)
    return 1
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())