# rewrite-service / app.py
import re
import time
from collections import OrderedDict
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# 1. Configuration
MODEL_NAME = "google/flan-t5-small"  # small seq2seq model; CPU-friendly
MAX_INPUT_TOKENS = 512  # tokenizer truncation limit for prompts
MAX_OUTPUT_TOKENS = 300  # cap on newly generated tokens per call
BATCH_SEPARATOR = "\n\n"  # blank line separates batch items in input and output
CACHE_MAX_ITEMS = 512  # LRU cache capacity (number of entries, not bytes)
CACHE_VERSION = "v3"  # bump to invalidate all previously cached outputs
APP_START_TIME = time.time()  # recorded once for uptime reporting
# Model and tokenizer load once at import time (module-level side effect);
# startup blocks until the weights are downloaded or found in the local cache.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.eval()  # inference mode: disables dropout etc.
print("Model loaded.")
# 2. In-memory LRU cache
# Maps cache key -> (generated_text, latency_ms). OrderedDict order doubles
# as recency order: oldest entry sits at the front and is evicted first.
_cache: OrderedDict[str, tuple[str, int]] = OrderedDict()
def cache_get(key: str):
    """Return the cached (text, latency_ms) for *key*, or None on a miss.

    A hit refreshes the entry's recency so it is evicted last.
    """
    try:
        value = _cache[key]
    except KeyError:
        return None
    _cache.move_to_end(key)
    return value
def cache_set(key: str, value: tuple[str, int]):
    """Store *value* under *key*, evicting the least-recently-used entry
    whenever the cache grows past CACHE_MAX_ITEMS."""
    _cache[key] = value
    # Assignment to an existing key does not move it, so refresh explicitly.
    _cache.move_to_end(key)
    if len(_cache) > CACHE_MAX_ITEMS:
        _cache.popitem(last=False)
# 3. Utilities
def normalize_text(text: str) -> str:
    """Trim *text* and collapse every internal whitespace run to one space."""
    return re.sub(r"\s+", " ", text.strip())
def split_batch(text: str) -> list[str]:
    """Split *text* on the batch separator and normalize each non-blank chunk."""
    chunks = text.split(BATCH_SEPARATOR)
    return [normalize_text(chunk) for chunk in chunks if chunk.strip()]
def split_sentences(text: str) -> list[str]:
    """Split *text* into sentences at whitespace following '.', '!' or '?'.

    Blank fragments are dropped; each sentence is returned stripped.
    """
    out = []
    for piece in re.split(r"(?<=[.!?])\s+", text):
        piece = piece.strip()
        if piece:
            out.append(piece)
    return out
# Mode -> instruction prefix; build_prompt appends the user text verbatim.
# (String concatenation, not str.format, so braces in user text are safe.)
_PROMPT_PREFIXES = {
    "rewrite": "Rewrite the following text clearly and naturally:\n",
    "paraphrase": "Paraphrase the following text using different wording:\n",
    "normalize": (
        "Rewrite the following text as a single, clear, simple sentence "
        "using correct grammar:\n"
    ),
    "definition": (
        "Give a short and clear definition (1–2 sentences) of the following:\n"
    ),
    "article": (
        "Write a 400 word article on the following topic.\n"
        "Rules:\n"
        "- Neutral tone\n"
        "- Short paragraphs\n"
        "- Simple English\n"
        "- No markdown\n"
        "- No emojis\n\n"
        "Topic: "
    ),
}
def build_prompt(mode: str, text: str) -> str:
    """Return the generation prompt for *mode*.

    Unknown modes pass *text* through unchanged.
    """
    prefix = _PROMPT_PREFIXES.get(mode)
    return text if prefix is None else prefix + text
def build_cache_key(mode: str, text: str) -> str:
    """Compose the cache key; CACHE_VERSION lets stale entries be invalidated wholesale."""
    return "|".join((CACHE_VERSION, MODEL_NAME, mode, text))
# 4. Inference
@torch.inference_mode()
def generate_text(prompt: str) -> tuple[str, int]:
    """Run one beam-search generation and return (decoded_text, latency_ms).

    The prompt is truncated to MAX_INPUT_TOKENS; output length is capped at
    MAX_OUTPUT_TOKENS new tokens.
    """
    t0 = time.perf_counter()
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )
    generated = model.generate(
        **encoded,
        max_new_tokens=MAX_OUTPUT_TOKENS,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,  # avoid verbatim 3-gram repetition loops
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    return decoded, elapsed_ms
# -----------------------------
# JSON-only API wrappers
# -----------------------------
def json_rewrite_api(payload: dict):
    """
    JSON-only rewrite endpoint.

    JSON contract:
    {
        "text": string | list[string],
        "mode": string,
        "batch": bool
    }

    Returns {"ok": True, ...} on success or {"ok": False, "error": ...}.
    """
    try:
        mode = payload.get("mode", "rewrite")
        text = payload.get("text")
        is_batch = payload.get("batch", False)
        if not text:
            return {"ok": False, "error": "text is required"}
        # Batch as list
        if is_batch:
            if not isinstance(text, list):
                return {"ok": False, "error": "batch=true requires text as list"}
            # Process each item independently. The old approach joined items
            # with BATCH_SEPARATOR and re-split the combined output, which
            # misaligned results whenever an output contained a blank line
            # or an input item was dropped as empty — count no longer
            # matched the caller's list. One call per item keeps a strict
            # 1:1 mapping between inputs and results.
            results = [process(item, mode) for item in text]
            return {
                "ok": True,
                "mode": mode,
                "count": len(results),
                "results": results
            }
        # Single
        output = process(text, mode)
        return {
            "ok": True,
            "mode": mode,
            "result": output
        }
    except Exception as ex:
        # API boundary: surface any failure as a JSON error payload
        # rather than letting the exception escape to the caller.
        return {
            "ok": False,
            "error": str(ex)
        }
def json_health_api(_: dict = None):
    """Liveness payload for the JSON API: model name, cache size, uptime."""
    return {
        "ok": True,
        "model": MODEL_NAME,
        "cache_items": len(_cache),
        "uptime_seconds": int(time.time() - APP_START_TIME),
    }
# 5. Main API function
def _rewrite_sentences(mode: str, item: str) -> str:
    """Normalize each sentence of *item* independently via the per-sentence cache."""
    pieces = []
    latency_sum = 0
    for sentence in split_sentences(item):
        key = build_cache_key(mode, sentence)
        hit = cache_get(key)
        if hit:
            pieces.append(hit[0])
            continue
        rewritten, ms = generate_text(build_prompt("normalize", sentence))
        cache_set(key, (rewritten, ms))
        pieces.append(rewritten)
        latency_sum += ms
    merged = " ".join(pieces)
    return f"{merged}\n(latency_ms={latency_sum}, cached=partial)"
def _rewrite_item(mode: str, item: str) -> str:
    """Run one item through cache-or-generate and annotate it with latency info."""
    key = build_cache_key(mode, item)
    hit = cache_get(key)
    if hit:
        return f"{hit[0]}\n(latency_ms=0, cached=true)"
    output, ms = generate_text(build_prompt(mode, item))
    cache_set(key, (output, ms))
    return f"{output}\n(latency_ms={ms}, cached=false)"
def process(text: str, mode: str) -> str:
    """Run *mode* over *text* and return annotated output.

    Paragraphs separated by BATCH_SEPARATOR are treated as a batch and the
    per-item results re-joined with the same separator. Each result carries
    a "(latency_ms=..., cached=...)" footer.
    """
    if not text or not text.strip():
        return ""
    results = []
    for item in split_batch(text):
        if mode == "normalize_sentences":
            results.append(_rewrite_sentences(mode, item))
        else:
            results.append(_rewrite_item(mode, item))
    return BATCH_SEPARATOR.join(results)
# 6. Health endpoint (NO INPUTS)
def health():
    """Health endpoint (no inputs): status, model name, cache size, uptime."""
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "cache_items": len(_cache),
        "uptime_seconds": int(time.time() - APP_START_TIME),
    }
# 7. Gradio app with multiple endpoints
# NOTE(review): json_rewrite_api / json_health_api are defined above but are
# not wired into any interface here — confirm whether they are meant to be
# exposed (e.g. via gr.api or an extra tab) or are dead code.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        gr.Dropdown(
            choices=[
                "rewrite",
                "paraphrase",
                "normalize",
                "normalize_sentences",
                "definition",
                "article"
            ],
            value="rewrite",
            label="Mode"
        )
    ],
    outputs=gr.Textbox(lines=14, label="Output"),
    title="Free Text Rewrite Service",
    description="Production-style HF Space with cache, batch, and health endpoint."
)
# Health tab takes no inputs and returns the health() dict as JSON.
health_api = gr.Interface(
    fn=health,
    inputs=[],
    outputs="json"
)
app = gr.TabbedInterface(
    [demo, health_api],
    ["Rewrite", "Health"]
)
# NOTE(review): queue()'s signature has changed across Gradio versions —
# confirm that a positional False is accepted here, or pass the intended
# keyword argument explicitly for the pinned Gradio version.
app.queue(False)
# 0.0.0.0:7860 is the standard Hugging Face Spaces binding.
app.launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True
)