File size: 2,379 Bytes
3df5819 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | """
Interactive inference script.
Run: python scripts/run_inference.py --config configs/inference_config.yaml
"""
from typing import Optional

import click
import yaml
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

from src.inference.corrector import AcademicCorrector
# Shared rich console: every status message, panel, table and prompt in this
# script is printed/read through this single instance.
console = Console()
def _print_full_report(result) -> None:
    """Print the original text, its correction, and a metrics table.

    ``result`` is the object returned by ``AcademicCorrector.correct``. This
    reads its ``original``, ``corrected``, ``style_similarity``,
    ``awl_coverage`` and ``readability`` attributes.
    """
    console.print(Panel(result.original, title="Original", border_style="red"))
    console.print(Panel(result.corrected, title="Corrected", border_style="green"))

    table = Table(title="Metrics")
    table.add_column("Metric")
    table.add_column("Value")
    table.add_row("Style Similarity", f"{result.style_similarity:.4f}")
    table.add_row("AWL Coverage", f"{result.awl_coverage:.4f}")
    # ``readability`` is treated as a mapping (``.items()``) whose values
    # accept ``:.2f`` formatting; one table row per entry.
    for name, score in result.readability.items():
        table.add_row(name, f"{score:.2f}")
    console.print(table)


def _run_interactive(
    corrector: AcademicCorrector,
    master_copy: Optional[str],
    style_alpha: float,
) -> None:
    """Repeatedly prompt for text and print each correction with brief metrics.

    Blank input is skipped. The session ends on Ctrl+C (``KeyboardInterrupt``)
    or end of input (``EOFError``, e.g. Ctrl+D or exhausted piped stdin).
    """
    console.print("[bold yellow]Interactive mode. Type text to correct (Ctrl+C to exit).[/]")
    while True:
        try:
            console.print()
            user_input = console.input("[bold cyan]Enter text: [/]")
            if not user_input.strip():
                continue
            # Forward --master-copy so style matching works the same way as in
            # single-shot (--text) mode.
            result = corrector.correct(
                user_input, master_copy=master_copy, style_alpha=style_alpha
            )
            console.print(Panel(result.corrected, title="Corrected", border_style="green"))
            console.print(f" Style: {result.style_similarity:.3f} | AWL: {result.awl_coverage:.3f}")
        except (KeyboardInterrupt, EOFError):
            # input() raises EOFError at end of stream; exit cleanly like
            # Ctrl+C instead of crashing with a traceback.
            console.print("\n[bold red]Goodbye![/]")
            break


@click.command()
@click.option(
    "--config",
    default="configs/inference_config.yaml",
    type=click.Path(exists=True, dir_okay=False),
    show_default=True,
    help="Path to the inference YAML config",
)
@click.option("--text", default=None, help="Text to correct")
@click.option("--master-copy", default=None, help="Optional master copy for style matching")
@click.option(
    "--style-alpha",
    default=0.6,
    type=float,
    show_default=True,
    help="Style blend weight (0=master, 1=user)",
)
def run_inference(
    config: str,
    text: Optional[str],
    master_copy: Optional[str],
    style_alpha: float,
) -> None:
    """Run inference on text input.

    With --text, correct that text once and print the original, the
    correction and a metrics table. Without --text, start an interactive
    prompt that corrects each entered line until Ctrl+C or end of input.
    """
    # safe_load only builds plain YAML types (no arbitrary Python objects).
    with open(config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    console.print("[bold cyan]Loading model...[/]")
    corrector = AcademicCorrector(cfg)
    console.print("[bold green]✓ Model loaded[/]")

    # An empty --text ("") is falsy and therefore also enters interactive mode.
    if text:
        result = corrector.correct(text, master_copy=master_copy, style_alpha=style_alpha)
        _print_full_report(result)
    else:
        _run_interactive(corrector, master_copy, style_alpha)
if __name__ == "__main__":
    # Invoke the Click command; option values are taken from the command line.
    run_inference()
|