# atd / app.py
# cstr's picture
# Update app.py
# f15e7de verified
import http.client
import os
import re
import shutil
import socket
import subprocess
import sys
import tarfile
import time
import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET

import gradio as gr
# =============================================================================
# CONFIGURATION
# =============================================================================

# Upstream AtD server sources and the local checkout that holds them.
REPO_URL = "https://github.com/Automattic/atd-server-next.git"
SERVER_DIR = "atd-server-next"
MODELS_DIR = os.path.join(SERVER_DIR, "models")

# Address the local AtD HTTP server is expected to listen on.
HOST = "127.0.0.1"
PORT = 1049

# Portable Java runtime: install folder (relative to the working directory),
# the Adoptium Temurin 8 JRE (Linux x64) archive it comes from, and the
# absolute path of the resulting `java` launcher (resolved at import time).
JAVA_DIR = "jdk8"
JAVA_URL = (
    "https://github.com/adoptium/temurin8-binaries/releases/download/"
    "jdk8u392-b08/"
    "OpenJDK8U-jre_x64_linux_hotspot_8u392b08.tar.gz"
)
JAVA_BIN = os.path.join(os.getcwd(), JAVA_DIR, "bin", "java")

# Pre-built model files, each fetched as MODEL_BASE_URL + <file name>.
MODEL_BASE_URL = "https://openatd.svn.wordpress.org/atd-server/models/"
MODEL_FILES = [
    "cnetwork.bin",
    "cnetwork2.bin",
    "dictionary.txt",
    "edits.bin",
    "endings.bin",
    "hnetwork.bin",
    "hnetwork2.bin",
    "hnetwork4.bin",
    "lexicon.bin",
    "model.bin",
    "model.zip",
    "network3f.bin",
    "network3p.bin",
    "not_misspelled.txt",
    "stringpool.bin",
    "trigrams.bin",
]
# =============================================================================
# PHASE -1: INSTALL JAVA 8 (The "Time Machine" Fix)
# =============================================================================
def install_java8():
    """Download and install a portable Java 8 runtime into JAVA_DIR.

    Does nothing when JAVA_BIN already exists. Otherwise:

    1. Downloads the JAVA_URL archive.
    2. Extracts it into the current working directory.
    3. Renames its top-level folder to JAVA_DIR.
    4. Marks JAVA_BIN executable.

    Any failure is fatal: the error is printed and the process exits with
    status 1. The downloaded archive is removed in every case.
    """
    print("--- [PHASE -1] CHECKING JAVA RUNTIME ---")
    # If our local java binary exists, we assume it's installed.
    if os.path.exists(JAVA_BIN):
        print(f"Portable Java 8 found at: {JAVA_BIN}")
        return
    print("Java 8 not found. Downloading portable runtime (approx 40MB)...")
    tar_path = "java8.tar.gz"
    try:
        urllib.request.urlretrieve(JAVA_URL, tar_path)
        print("Download complete. Extracting...")
        with tarfile.open(tar_path, "r:gz") as tar:
            # Name of the archive's top-level folder: the first real path
            # component of the first member, ignoring a leading "./".
            root_name = None
            for member_name in tar.getnames():
                parts = [p for p in member_name.split("/") if p not in ("", ".")]
                if parts:
                    root_name = parts[0]
                    break
            if root_name is None:
                raise RuntimeError("Java archive contains no files")
            if hasattr(tarfile, "data_filter"):
                # Reject members that would land outside the working
                # directory (absolute paths, "..", escaping links) on
                # Python versions that support extraction filters.
                tar.extractall(filter="data")
            else:
                tar.extractall()
        # Rename the extracted folder to JAVA_DIR so JAVA_BIN is a stable path.
        # Skip if the archive already used that name (rmtree would delete it).
        if root_name != JAVA_DIR:
            if os.path.exists(JAVA_DIR):
                shutil.rmtree(JAVA_DIR)
            os.rename(root_name, JAVA_DIR)
        # Add execute permission (user, group, other) to the launcher.
        os.chmod(JAVA_BIN, os.stat(JAVA_BIN).st_mode | 0o111)
        print(f"Java 8 successfully installed to {JAVA_DIR}")
    except Exception as e:
        print(f"FATAL: Failed to install Java 8. Error: {e}")
        sys.exit(1)
    finally:
        # Remove the downloaded archive whether or not installation succeeded.
        if os.path.exists(tar_path):
            os.remove(tar_path)
# =============================================================================
# PHASES 0-2: REPOSITORY, MODELS & RULE COMPILATION
# =============================================================================
def setup_server():
    """Prepare the AtD server checkout so it can be launched.

    Steps, in order:
      * ensure the portable Java runtime is installed (``install_java8``);
      * shallow-clone REPO_URL into SERVER_DIR if that folder is missing;
      * download every MODEL_FILES entry that is missing from MODELS_DIR;
      * compile the grammar rules unless ``grammar.bin`` exists in MODELS_DIR.

    Failed model downloads and failed rule compilation are reported but not
    fatal. A failing ``git clone`` raises ``subprocess.CalledProcessError``.
    """
    # Java must be available before the rule compiler below can run.
    install_java8()

    print("\n--- [PHASE 0] CHECKING REPOSITORY ---")
    if not os.path.exists(SERVER_DIR):
        print(f"Cloning {REPO_URL}...")
        subprocess.run(
            ["git", "clone", "--depth", "1", REPO_URL, SERVER_DIR], check=True
        )

    print("\n--- [PHASE 1] CHECKING MODELS ---")
    os.makedirs(MODELS_DIR, exist_ok=True)
    for filename in MODEL_FILES:
        filepath = os.path.join(MODELS_DIR, filename)
        if os.path.exists(filepath):
            continue
        print(f"Downloading {filename}...")
        # Download to a temporary name and move it into place only on
        # success. Otherwise an interrupted transfer would leave a truncated
        # file that later runs treat as complete (they only check existence).
        part_path = filepath + ".part"
        try:
            urllib.request.urlretrieve(MODEL_BASE_URL + filename, part_path)
            os.replace(part_path, filepath)
        except Exception as e:
            print(f" -> FAILED: {e}")
            try:
                os.remove(part_path)
            except OSError:
                pass  # nothing was written, or it cannot be removed

    print("\n--- [PHASE 2] COMPILING RULES ---")
    # Only compile if grammar.bin doesn't exist (heuristic check).
    if not os.path.exists(os.path.join(MODELS_DIR, "grammar.bin")):
        print("Compiling rules using Java 8...")
        # Classpath separator is ':' on Linux; paths are relative to SERVER_DIR.
        cp = "lib/sleep.jar:lib/moconti.jar:lib/spellutils.jar"
        try:
            subprocess.run(
                [
                    JAVA_BIN,
                    "-Datd.lowmem=true",
                    "-Xmx1024M",
                    "-classpath", cp,
                    "sleep.console.TextConsole",
                    "utils/rules/rules.sl",
                ],
                cwd=SERVER_DIR,
                check=True,
            )
            print("Rules compiled successfully.")
        except subprocess.CalledProcessError as e:
            print(f"Rule compilation warning: {e}")
# =============================================================================
# PHASE 3: SERVER RUNNER
# =============================================================================
def start_backend():
    """Launch the AtD server (``httpd.Moconti`` with ``atdconfig.sl``).

    The server runs as a child process with the portable JAVA_BIN and
    SERVER_DIR as its working directory. System properties tell it to listen
    on HOST:PORT. Its stdout/stderr are inherited from this process.

    Returns:
        The ``subprocess.Popen`` handle; the caller is responsible for
        stopping the process.
    """
    print("\n--- [PHASE 3] STARTING SERVER ---")
    # Java classpath (':'-separated on Linux), relative to SERVER_DIR.
    classpath = "lib/sleep.jar:lib/moconti.jar:lib/spellutils.jar"
    sleep_cp = "lib:service/code"
    cmd = [
        JAVA_BIN,
        "-Dfile.encoding=UTF-8",
        "-XX:+AggressiveHeap",
        "-XX:+UseParallelGC",
        "-Datd.lowmem=true",
        # Bind to the same address that wait_for_port() and AtDClient use.
        f"-Dbind.interface={HOST}",
        f"-Dserver.port={PORT}",
        f"-Dsleep.classpath={sleep_cp}",
        "-Dsleep.debug=24",
        "-classpath", classpath,
        "httpd.Moconti",
        "atdconfig.sl",
    ]
    print(f"Launching with: {JAVA_BIN}")
    return subprocess.Popen(cmd, cwd=SERVER_DIR)
def wait_for_port(timeout=60, host=None, port=None):
    """Poll until a TCP connection to the server can be opened.

    Args:
        timeout: Seconds to keep trying before giving up.
        host: Address to probe; defaults to HOST.
        port: Port to probe; defaults to PORT.

    Returns:
        True as soon as a connection is accepted, False once ``timeout``
        seconds have elapsed without success.
    """
    if host is None:
        host = HOST
    if port is None:
        port = PORT
    print(f"Waiting for port {port}...")
    # Monotonic clock, so system clock adjustments cannot stretch or cut
    # the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1):
                print("Server is Online!")
                return True
        except OSError:  # covers ConnectionRefusedError and socket timeouts
            time.sleep(1)
    return False
# =============================================================================
# CLIENT & GRADIO CALLBACK
# =============================================================================
class AtDClient:
    """Minimal HTTP client for the local AtD server's ``/checkDocument``."""

    def __init__(self, host=None, port=None, timeout=5):
        """Create a client.

        Args:
            host: Server address; None means the module-level HOST,
                resolved on each request.
            port: Server port; None means the module-level PORT.
            timeout: Socket timeout in seconds for each request.
        """
        self.host = host
        self.port = port
        self.timeout = timeout

    def check_document(self, text):
        """Submit ``text`` to the server and return the errors it reports.

        Returns:
            A list of dicts with keys ``string``, ``description``, ``type``,
            ``precontext`` (all str) and ``suggestions`` (list of str).
            Returns an empty list on any failure: non-200 status, non-XML
            body, parse error or network error. The failure is printed.
        """
        host = HOST if self.host is None else self.host
        port = PORT if self.port is None else self.port
        conn = None
        try:
            conn = http.client.HTTPConnection(host, port, timeout=self.timeout)
            params = urllib.parse.urlencode({'key': 'gradio', 'data': text})
            headers = {"Content-Type": "application/x-www-form-urlencoded"}
            conn.request("POST", "/checkDocument", params, headers)
            resp = conn.getresponse()
            if resp.status != 200:
                return []
            xml_text = resp.read().decode('utf-8', errors='ignore')
            return self._parse_errors(xml_text)
        except Exception as e:
            # Best effort: the UI should show "no errors" rather than crash.
            print(f"Client Error: {e}")
            return []
        finally:
            # Always release the socket, including on early returns.
            if conn is not None:
                conn.close()

    @staticmethod
    def _parse_errors(xml_text):
        """Convert an AtD XML response body into a list of error dicts.

        Missing or empty child elements become "" (no suggestions become []).
        Entries without an error ``string`` are dropped because there is
        nothing to highlight. Raises ``xml.etree.ElementTree.ParseError`` for
        malformed XML that does start with "<".
        """
        if not xml_text.strip().startswith("<"):
            return []
        root = ET.fromstring(xml_text)
        errors = []
        for node in root.findall('error'):
            string = node.findtext('string')
            if not string:
                continue
            suggestions = []
            suggestions_node = node.find('suggestions')
            if suggestions_node is not None:
                suggestions = [
                    option.text
                    for option in suggestions_node.findall('option')
                    if option.text
                ]
            errors.append({
                'string': string,
                'description': node.findtext('description') or "",
                'type': node.findtext('type') or "",
                'precontext': node.findtext('precontext') or "",
                'suggestions': suggestions,
            })
        return errors


client = AtDClient()
def highlight_errors(text, errors):
    """Split ``text`` into ``(segment, label)`` pairs for gr.HighlightedText.

    Args:
        text: The checked document.
        errors: Dicts with keys ``string``, ``precontext``, ``type``,
            ``description`` and ``suggestions``, as returned by
            ``AtDClient.check_document``.

    Errors are located left to right. Each one is searched for after the
    previous match (and after its ``precontext``, when that text is found).
    Errors that cannot be located are skipped. Unflagged text becomes
    ``(segment, None)``. Labels look like
    ``"type: description -> s1, s2, s3"``, with at most three suggestions.
    """
    def find_token(needle, start):
        # Prefer an occurrence that is not glued to surrounding word
        # characters, so e.g. "is" is not matched inside "This". Boundaries
        # are only required on sides where the needle itself starts/ends
        # with a word character. Fall back to a plain substring search.
        if not needle:
            return -1
        left = r"(?<!\w)" if re.match(r"\w", needle[0]) else ""
        right = r"(?!\w)" if re.match(r"\w", needle[-1]) else ""
        match = re.compile(left + re.escape(needle) + right).search(text, start)
        return match.start() if match else text.find(needle, start)

    segments = []
    last_pos = 0
    for err in errors:
        word = err['string']
        search_start = last_pos
        if err['precontext']:
            context_idx = find_token(err['precontext'], last_pos)
            if context_idx != -1:
                search_start = context_idx + len(err['precontext'])
        idx = find_token(word, search_start)
        if idx == -1:
            continue
        if idx > last_pos:
            segments.append((text[last_pos:idx], None))
        label = f"{err['type']}: {err['description']}"
        if err['suggestions']:
            label += f" -> {', '.join(err['suggestions'][:3])}"
        segments.append((text[idx:idx + len(word)], label))
        last_pos = idx + len(word)
    if last_pos < len(text):
        segments.append((text[last_pos:], None))
    return segments


def analyze_text(text):
    """Gradio callback: check ``text`` with the AtD server and highlight it.

    Returns an empty list for empty, whitespace-only or None input. Otherwise
    returns the ``(segment, label)`` pairs built by ``highlight_errors``.
    """
    if not text or not text.strip():
        return []
    return highlight_errors(text, client.check_document(text))
def main():
    """Prepare and start the AtD server, then serve the Gradio UI.

    Exits with status 1 if the Java server exits within the first five
    seconds or does not accept connections within 120 seconds. The Java
    process is stopped when this function returns or raises (including
    Ctrl-C), so it does not outlive the UI.
    """
    setup_server()
    server_proc = start_backend()
    try:
        # Give the JVM a moment, then make sure it did not crash on startup.
        time.sleep(5)
        if server_proc.poll() is not None:
            print("FATAL: Java server exited immediately.")
            sys.exit(1)
        if not wait_for_port(timeout=120):
            print("FATAL: Server did not start (Timeout).")
            sys.exit(1)
        with gr.Blocks(title="AtD Self-Hosted") as demo:
            gr.Markdown("# 🛡️ After The Deadline")
            gr.Markdown("Running on Portable Java 8")
            with gr.Row():
                inp = gr.Textbox(label="Input", placeholder="Type here... e.g., I has a error.", lines=6)
                out = gr.HighlightedText(label="Analysis", combine_adjacent=True)
            btn = gr.Button("Check Text", variant="primary")
            btn.click(analyze_text, inputs=inp, outputs=out)
        # Blocks until the UI server shuts down when run as a script.
        demo.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Stop the Java child: ask politely first, then force it.
        if server_proc.poll() is None:
            server_proc.terminate()
            try:
                server_proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                server_proc.kill()


if __name__ == "__main__":
    main()