# "Spaces: Sleeping" — Hugging Face Space status banner captured by the page
# scrape, not source code; kept as a comment so the file parses as Python.
| import re | |
| import time | |
| from collections import OrderedDict | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# 1. Configuration
MODEL_NAME = "google/flan-t5-small"  # seq2seq checkpoint pulled from the HF Hub
MAX_INPUT_TOKENS = 512   # tokenizer truncation limit applied to prompts
MAX_OUTPUT_TOKENS = 300  # generation budget (passed as max_new_tokens)
BATCH_SEPARATOR = "\n\n"  # delimiter between batch items in input and output text
CACHE_MAX_ITEMS = 512    # LRU cache capacity (number of entries, not bytes)
CACHE_VERSION = "v3"     # bump to invalidate every previously cached result
APP_START_TIME = time.time()  # process start, used by the health endpoints for uptime
# Model download/load happens at import time — first start of the Space is slow.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
model.eval()  # inference-only mode (disables dropout etc.)
print("Model loaded.")
| # 2. In-memory LRU cache | |
| _cache: OrderedDict[str, tuple[str, int]] = OrderedDict() | |
| def cache_get(key: str): | |
| if key not in _cache: | |
| return None | |
| _cache.move_to_end(key) | |
| return _cache[key] | |
| def cache_set(key: str, value: tuple[str, int]): | |
| _cache[key] = value | |
| _cache.move_to_end(key) | |
| if len(_cache) > CACHE_MAX_ITEMS: | |
| _cache.popitem(last=False) | |
# 3. Utilities
def normalize_text(text: str) -> str:
    """Collapse every whitespace run to a single space and trim the ends."""
    return re.sub(r"\s+", " ", text.strip())
def split_batch(text: str) -> list[str]:
    """Split *text* on the batch separator, dropping blanks and normalizing each item."""
    chunks = text.split(BATCH_SEPARATOR)
    return [normalize_text(chunk) for chunk in chunks if chunk.strip()]
def split_sentences(text: str) -> list[str]:
    """Split *text* on sentence-ending punctuation (. ! ?) followed by whitespace."""
    pieces = re.split(r'(?<=[.!?])\s+', text)
    stripped = (piece.strip() for piece in pieces)
    return [piece for piece in stripped if piece]
def build_prompt(mode: str, text: str) -> str:
    """Wrap *text* in the instruction template for *mode*.

    Unknown modes return *text* unchanged so the raw input becomes the prompt.
    """
    prefixes = {
        "rewrite": "Rewrite the following text clearly and naturally:\n",
        "paraphrase": "Paraphrase the following text using different wording:\n",
        "normalize": (
            "Rewrite the following text as a single, clear, simple sentence "
            "using correct grammar:\n"
        ),
        "definition": (
            "Give a short and clear definition (1–2 sentences) of the following:\n"
        ),
        "article": (
            "Write a 400 word article on the following topic.\n"
            "Rules:\n"
            "- Neutral tone\n"
            "- Short paragraphs\n"
            "- Simple English\n"
            "- No markdown\n"
            "- No emojis\n\n"
            "Topic: "
        ),
    }
    prefix = prefixes.get(mode)
    if prefix is None:
        return text
    return prefix + text
def build_cache_key(mode: str, text: str) -> str:
    """Compose the cache key; version + model are included so stale entries never match."""
    parts = (CACHE_VERSION, MODEL_NAME, mode, text)
    return "|".join(parts)
# 4. Inference
def generate_text(prompt: str) -> tuple[str, int]:
    """Run one seq2seq generation and return (decoded_text, latency_ms).

    The prompt is truncated to MAX_INPUT_TOKENS; decoding uses 4-beam search
    with a no-repeat trigram constraint and up to MAX_OUTPUT_TOKENS new tokens.

    Fix: generation now runs under torch.no_grad() — model.eval() alone does
    not disable gradient tracking, so the original built an autograd graph on
    every request, wasting memory and time during pure inference.
    """
    start = time.perf_counter()
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
    )
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_OUTPUT_TOKENS,
            num_beams=4,
            early_stopping=True,
            no_repeat_ngram_size=3,
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    latency_ms = int((time.perf_counter() - start) * 1000)
    return text, latency_ms
# -----------------------------
# JSON-only API wrappers
# -----------------------------
def json_rewrite_api(payload: dict):
    """
    JSON contract:
    {
        "text": string | list[string],
        "mode": string,
        "batch": bool
    }

    Returns {"ok": True, ...} on success or {"ok": False, "error": ...}.
    Never raises: the except boundary converts any failure to error JSON.
    """
    try:
        mode = payload.get("mode", "rewrite")
        text = payload.get("text")
        is_batch = payload.get("batch", False)
        if not text:
            return {"ok": False, "error": "text is required"}
        # Batch as list
        if is_batch:
            if not isinstance(text, list):
                return {"ok": False, "error": "batch=true requires text as list"}
            # Fix: process each item independently instead of joining with
            # BATCH_SEPARATOR and re-splitting the combined output — a result
            # containing the separator would desync the returned count from
            # the number of submitted items.
            results = [process(item, mode) for item in text]
            return {
                "ok": True,
                "mode": mode,
                "count": len(results),
                "results": results
            }
        # Single
        output = process(text, mode)
        return {
            "ok": True,
            "mode": mode,
            "result": output
        }
    except Exception as ex:
        return {
            "ok": False,
            "error": str(ex)
        }
def json_health_api(_: dict = None):
    """Liveness probe: report model id, cache entry count, and uptime as JSON.

    The single parameter is ignored; it exists only to match the JSON API shape.
    """
    return {
        "ok": True,
        "model": MODEL_NAME,
        "cache_items": len(_cache),
        "uptime_seconds": int(time.time() - APP_START_TIME),
    }
# 5. Main API function
def _rewrite_sentencewise(item: str) -> str:
    """Normalize *item* sentence by sentence, caching each sentence separately.

    Returns the joined sentences plus a latency/cache footer. Cached sentences
    contribute 0 ms to the reported total (only fresh generations are counted).
    """
    normalized = []
    total_latency = 0
    for sentence in split_sentences(item):
        cache_key = build_cache_key("normalize_sentences", sentence)
        cached = cache_get(cache_key)
        if cached is not None:
            output, _ = cached
            normalized.append(output)
            continue
        prompt = build_prompt("normalize", sentence)
        output, latency_ms = generate_text(prompt)
        cache_set(cache_key, (output, latency_ms))
        normalized.append(output)
        total_latency += latency_ms
    final_text = " ".join(normalized)
    return f"{final_text}\n(latency_ms={total_latency}, cached=partial)"


def _rewrite_item(item: str, mode: str) -> str:
    """Run one whole item through the model with a whole-item cache lookup."""
    cache_key = build_cache_key(mode, item)
    cached = cache_get(cache_key)
    if cached is not None:
        output, _ = cached
        return f"{output}\n(latency_ms=0, cached=true)"
    prompt = build_prompt(mode, item)
    output, latency_ms = generate_text(prompt)
    cache_set(cache_key, (output, latency_ms))
    return f"{output}\n(latency_ms={latency_ms}, cached=false)"


def process(text: str, mode: str) -> str:
    """Process *text* (one item, or several separated by BATCH_SEPARATOR).

    Empty/whitespace-only input yields "". Each item's result carries a
    latency/cache annotation; results are rejoined with BATCH_SEPARATOR.
    Decomposed into per-item helpers; behavior is unchanged.
    """
    if not text or not text.strip():
        return ""
    results = []
    for item in split_batch(text):
        if mode == "normalize_sentences":
            results.append(_rewrite_sentencewise(item))
        else:
            results.append(_rewrite_item(item, mode))
    return BATCH_SEPARATOR.join(results)
# 6. Health endpoint (NO INPUTS)
def health():
    """Return service status: model id, cache entry count, and uptime seconds."""
    return {
        "status": "ok",
        "model": MODEL_NAME,
        "cache_items": len(_cache),
        "uptime_seconds": int(time.time() - APP_START_TIME),
    }
# 7. Gradio app with multiple endpoints
# Main rewrite UI: maps process(text, mode) -> one annotated text output.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        gr.Dropdown(
            choices=[
                "rewrite",
                "paraphrase",
                "normalize",
                "normalize_sentences",
                "definition",
                "article"
            ],
            value="rewrite",
            label="Mode"
        )
    ],
    outputs=gr.Textbox(lines=14, label="Output"),
    title="Free Text Rewrite Service",
    description="Production-style HF Space with cache, batch, and health endpoint."
)
# Zero-input health probe exposed as its own tab.
health_api = gr.Interface(
    fn=health,
    inputs=[],
    outputs="json"
)
# Two tabs: the rewrite UI and the health probe.
app = gr.TabbedInterface(
    [demo, health_api],
    ["Rewrite", "Health"]
)
# NOTE(review): Blocks.queue takes keyword configuration in current Gradio;
# a positional False here may not mean "disable queueing" — confirm against
# the installed Gradio version's queue() signature.
app.queue(False)
# 0.0.0.0:7860 is the standard bind for a Hugging Face Space container.
app.launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True
)