"""Free text-rewriting service.

A Gradio HF-Space app around FLAN-T5 with an in-memory LRU cache,
blank-line-separated batch input, JSON API wrappers, and a health endpoint.
"""

import re
import time
from collections import OrderedDict

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# 1. Configuration
MODEL_NAME = "google/flan-t5-small"
MAX_INPUT_TOKENS = 512    # encoder-side truncation limit
MAX_OUTPUT_TOKENS = 300   # generation budget per request
BATCH_SEPARATOR = "\n\n"  # blank line separates items in batch input/output
CACHE_MAX_ITEMS = 512     # LRU capacity; least-recently-used entries evicted
CACHE_VERSION = "v3"      # bump to invalidate every existing cache key
APP_START_TIME = time.time()

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.eval()
print("Model loaded.")

# 2. In-memory LRU cache: cache key -> (output_text, latency_ms)
_cache: OrderedDict[str, tuple[str, int]] = OrderedDict()


def cache_get(key: str) -> tuple[str, int] | None:
    """Return the cached (text, latency_ms) for *key*, refreshing its LRU
    position, or None on a miss."""
    if key not in _cache:
        return None
    _cache.move_to_end(key)
    return _cache[key]


def cache_set(key: str, value: tuple[str, int]) -> None:
    """Store *value* under *key* and evict the least-recently-used entry
    once the cache exceeds CACHE_MAX_ITEMS."""
    _cache[key] = value
    _cache.move_to_end(key)
    if len(_cache) > CACHE_MAX_ITEMS:
        _cache.popitem(last=False)


# 3. Utilities
def normalize_text(text: str) -> str:
    """Trim *text* and collapse every whitespace run to a single space."""
    text = text.strip()
    text = re.sub(r"\s+", " ", text)
    return text


def split_batch(text: str) -> list[str]:
    """Split batch input on blank lines, normalizing and dropping empties."""
    return [normalize_text(t) for t in text.split(BATCH_SEPARATOR) if t.strip()]


def split_sentences(text: str) -> list[str]:
    """Naive sentence split on terminal punctuation followed by whitespace."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return [s.strip() for s in sentences if s.strip()]


def build_prompt(mode: str, text: str) -> str:
    """Return the instruction prompt for *mode*; unknown modes pass *text*
    through unchanged."""
    if mode == "rewrite":
        return f"Rewrite the following text clearly and naturally:\n{text}"
    if mode == "paraphrase":
        return f"Paraphrase the following text using different wording:\n{text}"
    if mode == "normalize":
        return (
            "Rewrite the following text as a single, clear, simple sentence "
            "using correct grammar:\n"
            f"{text}"
        )
    if mode == "definition":
        return (
            "Give a short and clear definition (1–2 sentences) of the following:\n"
            f"{text}"
        )
    if mode == "article":
        return (
            "Write a 400 word article on the following topic.\n"
            "Rules:\n"
            "- Neutral tone\n"
            "- Short paragraphs\n"
            "- Simple English\n"
            "- No markdown\n"
            "- No emojis\n\n"
            f"Topic: {text}"
        )
    return text


def build_cache_key(mode: str, text: str) -> str:
    """Compose a cache key; includes version and model so either changing
    invalidates old entries."""
    return f"{CACHE_VERSION}|{MODEL_NAME}|{mode}|{text}"


# 4. Inference
@torch.inference_mode()
def generate_text(prompt: str) -> tuple[str, int]:
    """Run one beam-search generation for *prompt*.

    Returns (decoded_text, wall_clock_latency_ms). Input is truncated to
    MAX_INPUT_TOKENS; output capped at MAX_OUTPUT_TOKENS new tokens.
    """
    start = time.perf_counter()
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_OUTPUT_TOKENS,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=3,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    latency_ms = int((time.perf_counter() - start) * 1000)
    return text, latency_ms


# -----------------------------
# JSON-only API wrappers
# -----------------------------
def json_rewrite_api(payload: dict):
    """JSON API entry point.

    JSON contract:
    {
      "text": string | list[string],
      "mode": string,
      "batch": bool
    }

    Returns {"ok": True, ...} on success or {"ok": False, "error": ...};
    never raises (API boundary).
    """
    try:
        mode = payload.get("mode", "rewrite")
        text = payload.get("text")
        is_batch = payload.get("batch", False)

        if not text:
            return {"ok": False, "error": "text is required"}

        # Batch as list
        if is_batch:
            if not isinstance(text, list):
                return {"ok": False, "error": "batch=true requires text as list"}
            joined = BATCH_SEPARATOR.join(text)
            output = process(joined, mode)
            results = output.split(BATCH_SEPARATOR)
            return {
                "ok": True,
                "mode": mode,
                "count": len(results),
                "results": results,
            }

        # Single: reject non-strings explicitly instead of surfacing a raw
        # AttributeError from process().
        if not isinstance(text, str):
            return {"ok": False, "error": "batch=false requires text as string"}
        output = process(text, mode)
        return {
            "ok": True,
            "mode": mode,
            "result": output,
        }
    except Exception as ex:  # top-level API boundary: report, don't crash
        return {
            "ok": False,
            "error": str(ex),
        }


def json_health_api(_: dict | None = None):
    """JSON health probe; the ignored parameter keeps the Gradio fn shape."""
    uptime = int(time.time() - APP_START_TIME)
    return {
        "ok": True,
        "model": MODEL_NAME,
        "cache_items": len(_cache),
        "uptime_seconds": uptime,
    }


# 5. Main API function
def _normalize_item_sentences(item: str) -> str:
    """Normalize *item* sentence-by-sentence with per-sentence caching.

    Returns the rebuilt text plus a footer reporting total generation
    latency and an accurate cache status (true / partial / false).
    """
    sentences = split_sentences(item)
    normalized: list[str] = []
    total_latency = 0
    hits = 0
    for sentence in sentences:
        cache_key = build_cache_key("normalize_sentences", sentence)
        cached = cache_get(cache_key)
        if cached is not None:
            output, _ = cached
            normalized.append(output)
            hits += 1
            continue
        prompt = build_prompt("normalize", sentence)
        output, latency_ms = generate_text(prompt)
        cache_set(cache_key, (output, latency_ms))
        normalized.append(output)
        total_latency += latency_ms
    # Report cache status honestly instead of always claiming "partial".
    if sentences and hits == len(sentences):
        cached_flag = "true"
    elif hits == 0:
        cached_flag = "false"
    else:
        cached_flag = "partial"
    final_text = " ".join(normalized)
    return f"{final_text}\n(latency_ms={total_latency}, cached={cached_flag})"


def process(text: str, mode: str) -> str:
    """Rewrite *text* according to *mode*.

    Input may contain several items separated by blank lines; each item is
    processed (with caching) and the outputs are re-joined with blank lines.
    Returns "" for empty/whitespace-only input.
    """
    if not text or not text.strip():
        return ""

    results = []
    for item in split_batch(text):
        if mode == "normalize_sentences":
            results.append(_normalize_item_sentences(item))
            continue

        cache_key = build_cache_key(mode, item)
        cached = cache_get(cache_key)
        if cached is not None:
            output, _ = cached
            results.append(f"{output}\n(latency_ms=0, cached=true)")
            continue

        prompt = build_prompt(mode, item)
        output, latency_ms = generate_text(prompt)
        cache_set(cache_key, (output, latency_ms))
        results.append(f"{output}\n(latency_ms={latency_ms}, cached=false)")

    return BATCH_SEPARATOR.join(results)


# 6. Health endpoint (NO INPUTS)
def health():
    """Health status for the Gradio 'Health' tab."""
    uptime = int(time.time() - APP_START_TIME)
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "cache_items": len(_cache),
        "uptime_seconds": uptime,
    }


# 7. Gradio app with multiple endpoints
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        gr.Dropdown(
            choices=[
                "rewrite",
                "paraphrase",
                "normalize",
                "normalize_sentences",
                "definition",
                "article",
            ],
            value="rewrite",
            label="Mode",
        ),
    ],
    outputs=gr.Textbox(lines=14, label="Output"),
    title="Free Text Rewrite Service",
    description="Production-style HF Space with cache, batch, and health endpoint.",
)

health_api = gr.Interface(
    fn=health,
    inputs=[],
    outputs="json",
)

app = gr.TabbedInterface(
    [demo, health_api],
    ["Rewrite", "Health"],
)

# NOTE(review): the original called app.queue(False). Blocks.queue() takes
# keyword configuration only; a positional False binds to its first parameter
# (concurrency_count in Gradio 3.x, status_update_rate in 4.x) and does NOT
# disable the queue in any version, so the call was dropped.
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )