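"""
Inference script for the AcademicCorrector model.

Corrects a single passage given with --text; when --text is omitted, starts an
interactive prompt that corrects each entered passage until Ctrl+C.

Run: python scripts/run_inference.py --config configs/inference_config.yaml
"""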
|
|
from typing import Optional

import click
import yaml
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from src.inference.corrector import AcademicCorrector
|
|
# Shared rich console used by the command below for prompts and all output.
console: Console = Console()
|
|
|
|
def _show_full_result(result) -> None:
    """Print the original and corrected text followed by a metrics table.

    ``result`` is the object returned by ``AcademicCorrector.correct``; this
    reads its ``original``, ``corrected``, ``style_similarity``,
    ``awl_coverage`` and ``readability`` attributes (``readability`` is
    iterated with ``.items()`` and each value formatted with ``:.2f``).
    """
    # Wrap user-derived text in Text so rich prints it literally instead of
    # parsing square brackets (e.g. "[sic]", "[/x]") as console markup.
    console.print(Panel(Text(result.original), title="Original", border_style="red"))
    console.print(Panel(Text(result.corrected), title="Corrected", border_style="green"))

    table = Table(title="Metrics")
    table.add_column("Metric")
    table.add_column("Value")
    table.add_row("Style Similarity", f"{result.style_similarity:.4f}")
    table.add_row("AWL Coverage", f"{result.awl_coverage:.4f}")
    for name, value in result.readability.items():
        table.add_row(name, f"{value:.2f}")
    console.print(table)


def _run_interactive(
    corrector: AcademicCorrector,
    master_copy: Optional[str],
    style_alpha: float,
) -> None:
    """Repeatedly prompt for text and print a compact correction for each entry.

    Blank entries are skipped. The loop ends cleanly on Ctrl+C
    (KeyboardInterrupt) or on end of input (Ctrl+D / exhausted piped stdin,
    which makes ``console.input`` raise EOFError).
    """
    console.print("[bold yellow]Interactive mode. Type text to correct (Ctrl+C to exit).[/]")
    try:
        while True:
            console.print()
            user_input = console.input("[bold cyan]Enter text: [/]")
            if not user_input.strip():
                continue
            # Forward master_copy here too; previously it was silently
            # ignored in interactive mode.
            result = corrector.correct(
                user_input, master_copy=master_copy, style_alpha=style_alpha
            )
            console.print(Panel(Text(result.corrected), title="Corrected", border_style="green"))
            console.print(f" Style: {result.style_similarity:.3f} | AWL: {result.awl_coverage:.3f}")
    except (KeyboardInterrupt, EOFError):
        console.print("\n[bold red]Goodbye![/]")


@click.command()
@click.option(
    "--config",
    default="configs/inference_config.yaml",
    type=click.Path(exists=True, dir_okay=False),
    help="Path to the YAML inference config.",
)
@click.option("--text", default=None, help="Text to correct")
@click.option("--master-copy", default=None, help="Optional master copy for style matching")
@click.option(
    "--style-alpha",
    default=0.6,
    type=float,
    help="Style blend weight (0=master, 1=user)",
)
def run_inference(
    config: str,
    text: Optional[str],
    master_copy: Optional[str],
    style_alpha: float,
):
    """Run inference on text input.

    With a non-empty --text, correct that passage and print the original,
    the correction and a metrics table. Otherwise start an interactive prompt
    that corrects each entered passage until Ctrl+C or end of input.
    --master-copy and --style-alpha are forwarded to the corrector in both
    modes.
    """
    with open(config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    console.print("[bold cyan]Loading model...[/]")
    corrector = AcademicCorrector(cfg)
    console.print("[bold green]✓ Model loaded[/]")

    if text:
        result = corrector.correct(text, master_copy=master_copy, style_alpha=style_alpha)
        _show_full_result(result)
    else:
        _run_interactive(corrector, master_copy, style_alpha)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: invoke the click command, which parses sys.argv.
    run_inference()
|
|