# text-to-json / app.py
# suneeldk's picture
# Update app.py
# 0d6ff7a verified
import gradio as gr
import json
import re
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
from peft import PeftModel
# ── Load model once at startup ──────────────────────────────
# Base checkpoint plus a LoRA adapter fine-tuned for JSON extraction.
BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
LORA_MODEL = "suneeldk/json-extract" # ← change this

# Tokenizer comes from the adapter repo so its vocab/special tokens match
# whatever was used during fine-tuning.
tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL)

# 4-bit NF4 quantization keeps the 1.5B model within a small GPU's memory;
# compute happens in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on available devices
)
# Attach the LoRA weights, then fold them into the base model so inference
# runs without the PEFT wrapper.
# NOTE(review): merging a LoRA into a 4-bit quantized base can lose
# precision — confirm this matches how the adapter was validated.
model = PeftModel.from_pretrained(base_model, LORA_MODEL)
model = model.merge_and_unload()
model.eval()
# ── Stop generation when JSON is complete ───────────────────
class StopOnJsonComplete(StoppingCriteria):
    """Stop generating once the newly generated text forms a complete JSON object.

    Brace depth is tracked while skipping braces that appear inside JSON
    string literals (honoring backslash escapes), so a value such as
    {"note": "}"} neither stops generation early nor confuses the count.
    """

    def __init__(self, tokenizer, prompt_length):
        self.tokenizer = tokenizer
        # Number of prompt tokens to skip when isolating generated text.
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores, **kwargs):
        new_tokens = input_ids[0][self.prompt_length:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        if not text.startswith("{"):
            return False
        # Count braces to detect a complete top-level object, ignoring any
        # braces that occur inside string literals.
        depth = 0
        in_string = False
        escaped = False
        for char in text:
            if in_string:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return True  # JSON object is complete, stop!
        return False
# ── Extract first valid JSON from text ──────────────────────
def extract_json(text):
    """Return the first complete, parseable JSON object found in *text*.

    Brace depth is tracked while skipping braces inside JSON string
    literals (honoring backslash escapes), so values like {"s": "}"} are
    handled correctly. If a balanced span fails to parse, scanning resumes
    at the next opening brace with a clean state — the original version
    left ``depth`` corrupted after a failed parse and could miss a later
    valid object. Returns None when no valid object exists.
    """
    depth = 0
    start = None       # index of the current candidate's opening brace
    in_string = False  # inside a JSON string literal?
    escaped = False    # previous char was a backslash inside a string?
    for i, char in enumerate(text):
        if start is None:
            if char == "{":
                start = i
                depth = 1
            continue
        if in_string:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
            continue
        if char == '"':
            in_string = True
        elif char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                try:
                    return json.loads(text[start:i + 1])
                except json.JSONDecodeError:
                    # Not valid JSON; restart the search after this span.
                    start = None
    return None
# ── Auto-detect schema from text ────────────────────────────
def auto_schema(text):
    """Guess a reasonable extraction schema from keyword hints in *text*.

    Scans a lowercased copy of the input for topic keywords (money,
    people, dates, items, locations, travel, meetings) and assembles a
    field -> type-hint mapping. Falls back to a generic four-field schema
    when almost nothing matches.
    """
    lowered = text.lower()

    def hits(keywords):
        # Plain substring containment, matching the original heuristic
        # (deliberately loose: "on" matches inside longer words too).
        return any(word in lowered for word in keywords)

    schema = {}

    money_hit = hits(["paid", "sent", "received", "cost", "price", "rupees", "rs",
                      "β‚Ή", "$", "bought", "sold", "charged", "fee", "salary",
                      "budget", "owes", "owe", "lent", "borrowed", "fare", "rent"])
    if money_hit or any(ch.isdigit() for ch in text):
        schema["amount"] = "number|null"

    if hits(["to", "from", "with", "for", "by", "told", "asked",
             "met", "called", "emailed", "messaged", "owes", "owe"]):
        schema["person"] = "string|null"

    if hits(["jan", "feb", "mar", "apr", "may", "jun", "jul", "aug",
             "sep", "oct", "nov", "dec", "monday", "tuesday", "wednesday",
             "thursday", "friday", "saturday", "sunday", "today", "tomorrow",
             "yesterday", "morning", "evening", "night", "on", "at", "pm", "am"]):
        schema["date"] = "ISO date|null"

    if hits(["pm", "am", "morning", "evening", "night", "at"]):
        schema["time"] = "string|null"

    if hits(["bought", "ordered", "purchased", "delivered", "shipped",
             "kg", "litre", "pieces", "items", "pack", "bottle"]):
        schema["item"] = "string|null"
        schema["quantity"] = "string|null"

    if hits(["store", "shop", "restaurant", "station", "airport",
             "hotel", "office"]):
        schema["location"] = "string|null"

    if hits(["train", "flight", "bus", "booked", "ticket", "pnr",
             "travel", "trip", "journey"]):
        # Travel context: endpoints replace the single generic location.
        schema["from_location"] = "string|null"
        schema["to_location"] = "string|null"
        schema.pop("location", None)

    if hits(["meeting", "call", "discuss", "review", "presentation",
             "interview", "appointment", "schedule"]):
        schema["topic"] = "string|null"
        schema["note"] = "string|null"

    # Too little detected: fall back to a generic catch-all schema.
    if len(schema) <= 1:
        schema = {
            "amount": "number|null",
            "person": "string|null",
            "date": "ISO date|null",
            "note": "string|null",
        }
    return schema
# ── Inference function ──────────────────────────────────────
@spaces.GPU
def extract(text, custom_schema):
    """Run the model to extract structured JSON from *text*.

    Args:
        text: Natural-language input sentence.
        custom_schema: Optional JSON string describing the desired fields;
            when empty/blank, a schema is auto-detected from the text.

    Returns:
        A pair of strings: (pretty-printed extracted JSON, or the raw model
        output if no valid JSON was produced; pretty-printed schema used).
        Returns ("", "") for blank input and an error message for an
        unparseable custom schema.
    """
    if not text.strip():
        return "", ""
    if custom_schema and custom_schema.strip():
        try:
            schema = json.loads(custom_schema)
        except json.JSONDecodeError:
            return "Invalid JSON schema.", ""
    else:
        schema = auto_schema(text)
    schema_str = json.dumps(schema)
    # Prompt layout must match the fine-tuning template exactly.
    prompt = f"### Input: {text}\n### Schema: {schema_str}\n### Output:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]
    # Stop as soon as a complete JSON object has been generated.
    stop_criteria = StoppingCriteriaList([
        StopOnJsonComplete(tokenizer, prompt_length)
    ])
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,  # greedy decoding: deterministic extraction
            pad_token_id=tokenizer.eos_token_id,
            stopping_criteria=stop_criteria,
        )
    new_tokens = outputs[0][prompt_length:]
    output_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    # Extract just the JSON, ignore any trailing garbage.
    parsed = extract_json(output_part)
    # BUG FIX: `if parsed:` treated a valid-but-empty object ({}) as a
    # failure because an empty dict is falsy; test for None explicitly.
    if parsed is not None:
        return json.dumps(parsed, indent=2, ensure_ascii=False), json.dumps(schema, indent=2)
    return output_part, json.dumps(schema, indent=2)
# ── Example inputs ──────────────────────────────────────────
# Demo sentences for the Gradio Examples widget; each inner list supplies
# the single [text_input] field.
examples = [
    ["Paid 500 to Ravi for lunch on Jan 5"],
    ["Meeting with Sarah at 3pm tomorrow to discuss the project budget of $10,000"],
    ["Bought 3 kg of rice from Krishna Stores for 250 rupees on March 10"],
    ["Booked a train from Chennai to Bangalore on April 10 for 750 rupees"],
    ["Ravi owes me 300 for last week's dinner"],
    ["Ordered 2 pizzas and 1 coke from Dominos for 850 rupees"],
]
# ── Gradio UI ───────────────────────────────────────────────
with gr.Blocks(title="json-extract") as demo:
    gr.Markdown(
        """
# json-extract
Extract structured JSON from natural language text.
Just type a sentence β€” the model auto-detects the right schema and extracts clean JSON.
"""
    )
    with gr.Row():
        # Left column: input text, trigger button, optional schema override.
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="e.g. Paid 500 to Ravi for lunch on Jan 5",
                lines=3,
            )
            btn = gr.Button("Extract", variant="primary")
            with gr.Accordion("Advanced: Custom Schema (optional)", open=False):
                schema_input = gr.Textbox(
                    label="Custom JSON Schema",
                    placeholder='Leave empty for auto-detect, or enter e.g. {"amount": "number", "person": "string|null"}',
                    lines=3,
                )
        # Right column: extracted JSON plus the schema that was applied.
        with gr.Column():
            output = gr.Textbox(label="Extracted JSON", lines=10)
            detected_schema = gr.Textbox(label="Schema Used", lines=5)
    gr.Examples(
        examples=examples,
        inputs=[text_input],
    )
    # Both the button click and pressing Enter in the textbox run extraction.
    btn.click(fn=extract, inputs=[text_input, schema_input], outputs=[output, detected_schema])
    text_input.submit(fn=extract, inputs=[text_input, schema_input], outputs=[output, detected_schema])

demo.launch()