"""Gradio Space: extract structured JSON from free-form text with a LoRA-tuned Qwen model.

The app auto-detects a plausible schema from keywords in the input (or accepts a
user-supplied one), prompts the model, and stops generation as soon as a balanced
JSON object has been produced.
"""

import gradio as gr
import json
import re
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)
from peft import PeftModel

# ── Load model once at startup ──────────────────────────────
BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
LORA_MODEL = "suneeldk/json-extract"  # ← change this

tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL)

# 4-bit NF4 quantization keeps the 1.5B model within free-tier GPU memory.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, LORA_MODEL)
# Merge the LoRA weights into the base model so inference skips the adapter hooks.
model = model.merge_and_unload()
model.eval()


# ── Stop generation when JSON is complete ───────────────────
class StopOnJsonComplete(StoppingCriteria):
    """Stop generating once we have a complete JSON object.

    Decodes only the newly generated tokens each step and returns True as soon
    as the brace depth of the emitted text returns to zero — i.e. the first
    `{...}` object is closed. NOTE(review): braces inside string values would
    fool this counter; acceptable for the short outputs this model emits.
    """

    def __init__(self, tokenizer, prompt_length):
        self.tokenizer = tokenizer
        # Number of prompt tokens; everything after this index is generated text.
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores, **kwargs):
        new_tokens = input_ids[0][self.prompt_length:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        if not text.startswith("{"):
            return False
        # Count braces to detect complete JSON.
        depth = 0
        for char in text:
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return True  # JSON object is complete, stop!
        return False


# ── Extract first valid JSON from text ──────────────────────
def extract_json(text):
    """Find and return the first complete JSON object in text.

    Scans for a balanced `{...}` span and attempts to parse it; on a parse
    failure the scan continues from the next opening brace. Returns the parsed
    object, or None if no valid JSON object is found.
    """
    depth = 0
    start = None
    for i, char in enumerate(text):
        if char == "{":
            if start is None:
                start = i
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0 and start is not None:
                try:
                    return json.loads(text[start:i + 1])
                except json.JSONDecodeError:
                    # Candidate span wasn't valid JSON; keep scanning.
                    start = None
    return None


# ── Auto-detect schema from text ────────────────────────────
def auto_schema(text):
    """Heuristically build an extraction schema from keywords in *text*.

    Each keyword family (money, people, dates, items, travel, meetings, ...)
    contributes fields. Falls back to a generic schema when too few fields
    were detected. Returns a dict mapping field name -> type hint string.
    """
    text_lower = text.lower()
    schema = {}

    money_keywords = ["paid", "sent", "received", "cost", "price", "rupees",
                      "rs", "₹", "$", "bought", "sold", "charged", "fee",
                      "salary", "budget", "owes", "owe", "lent", "borrowed",
                      "fare", "rent"]
    # Any digit in the text also counts as a money/amount signal.
    if any(k in text_lower for k in money_keywords) or any(c.isdigit() for c in text):
        schema["amount"] = "number|null"

    person_keywords = ["to", "from", "with", "for", "by", "told", "asked",
                       "met", "called", "emailed", "messaged", "owes", "owe"]
    if any(k in text_lower for k in person_keywords):
        schema["person"] = "string|null"

    date_keywords = ["jan", "feb", "mar", "apr", "may", "jun", "jul", "aug",
                     "sep", "oct", "nov", "dec", "monday", "tuesday",
                     "wednesday", "thursday", "friday", "saturday", "sunday",
                     "today", "tomorrow", "yesterday", "morning", "evening",
                     "night", "on", "at", "pm", "am"]
    if any(k in text_lower for k in date_keywords):
        schema["date"] = "ISO date|null"

    if any(k in text_lower for k in ["pm", "am", "morning", "evening", "night", "at"]):
        schema["time"] = "string|null"

    item_keywords = ["bought", "ordered", "purchased", "delivered", "shipped",
                     "kg", "litre", "pieces", "items", "pack", "bottle"]
    if any(k in text_lower for k in item_keywords):
        schema["item"] = "string|null"
        schema["quantity"] = "string|null"

    location_keywords = ["store", "shop", "restaurant", "station", "airport",
                         "hotel", "office"]
    if any(k in text_lower for k in location_keywords):
        schema["location"] = "string|null"

    travel_keywords = ["train", "flight", "bus", "booked", "ticket", "pnr",
                       "travel", "trip", "journey"]
    if any(k in text_lower for k in travel_keywords):
        # Travel implies origin/destination; a single "location" is redundant.
        schema["from_location"] = "string|null"
        schema["to_location"] = "string|null"
        schema.pop("location", None)

    meeting_keywords = ["meeting", "call", "discuss", "review", "presentation",
                        "interview", "appointment", "schedule"]
    if any(k in text_lower for k in meeting_keywords):
        schema["topic"] = "string|null"

    schema["note"] = "string|null"

    # Too few detected fields -> generic catch-all schema.
    if len(schema) <= 1:
        schema = {
            "amount": "number|null",
            "person": "string|null",
            "date": "ISO date|null",
            "note": "string|null",
        }
    return schema


# ── Inference function ──────────────────────────────────────
@spaces.GPU
def extract(text, custom_schema):
    """Run the model on *text* and return (extracted JSON, schema used).

    *custom_schema*, if non-empty, must be a JSON object string; otherwise the
    schema is auto-detected. Both return values are pretty-printed strings for
    display in the Gradio textboxes.
    """
    if not text.strip():
        return "", ""

    if custom_schema and custom_schema.strip():
        try:
            schema = json.loads(custom_schema)
        except json.JSONDecodeError:
            return "Invalid JSON schema.", ""
    else:
        schema = auto_schema(text)

    schema_str = json.dumps(schema)
    # Prompt format must match the LoRA fine-tuning template.
    prompt = f"### Input: {text}\n### Schema: {schema_str}\n### Output:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    # Stop as soon as JSON is complete.
    stop_criteria = StoppingCriteriaList([
        StopOnJsonComplete(tokenizer, prompt_length)
    ])

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            stopping_criteria=stop_criteria,
        )

    new_tokens = outputs[0][prompt_length:]
    output_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Extract just the JSON, ignore any trailing garbage.
    parsed = extract_json(output_part)
    # BUGFIX: `if parsed:` would treat a valid-but-empty result ({}) as a
    # failure and fall back to raw model text; test for None explicitly.
    if parsed is not None:
        return json.dumps(parsed, indent=2, ensure_ascii=False), json.dumps(schema, indent=2)
    else:
        return output_part, json.dumps(schema, indent=2)


# ── Example inputs ──────────────────────────────────────────
examples = [
    ["Paid 500 to Ravi for lunch on Jan 5"],
    ["Meeting with Sarah at 3pm tomorrow to discuss the project budget of $10,000"],
    ["Bought 3 kg of rice from Krishna Stores for 250 rupees on March 10"],
    ["Booked a train from Chennai to Bangalore on April 10 for 750 rupees"],
    ["Ravi owes me 300 for last week's dinner"],
    ["Ordered 2 pizzas and 1 coke from Dominos for 850 rupees"],
]

# ── Gradio UI ───────────────────────────────────────────────
with gr.Blocks(title="json-extract") as demo:
    gr.Markdown(
        """
        # json-extract

        Extract structured JSON from natural language text.
        Just type a sentence — the model auto-detects the right schema and extracts clean JSON.
        """
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="e.g. Paid 500 to Ravi for lunch on Jan 5",
                lines=3,
            )
            btn = gr.Button("Extract", variant="primary")
            with gr.Accordion("Advanced: Custom Schema (optional)", open=False):
                schema_input = gr.Textbox(
                    label="Custom JSON Schema",
                    placeholder='Leave empty for auto-detect, or enter e.g. {"amount": "number", "person": "string|null"}',
                    lines=3,
                )
        with gr.Column():
            output = gr.Textbox(label="Extracted JSON", lines=10)
            detected_schema = gr.Textbox(label="Schema Used", lines=5)

    gr.Examples(
        examples=examples,
        inputs=[text_input],
    )

    btn.click(fn=extract, inputs=[text_input, schema_input], outputs=[output, detected_schema])
    text_input.submit(fn=extract, inputs=[text_input, schema_input], outputs=[output, detected_schema])

demo.launch()