Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import re | |
| import spaces | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList | |
| from peft import PeftModel | |
# ── Load model once at startup ──────────────────────────────
# Runs at import time: loads the tokenizer, a 4-bit quantized base model and
# the LoRA adapter, merges the adapter, and switches the model to eval mode.
BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
LORA_MODEL = "suneeldk/json-extract"  # ← change this
# The tokenizer is taken from the adapter repo, not the base repo
# (assumes the adapter repo ships tokenizer files — confirm).
tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL)
# bitsandbytes 4-bit NF4 quantization with float16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # let transformers/accelerate place the weights
)
# Attach the LoRA adapter, then fold its weights into the base model so the
# result behaves like a plain causal LM.
# NOTE(review): the base weights are 4-bit; PEFT warns that merging LoRA into
# 4-bit layers may change generations due to rounding. If outputs diverge
# from the trained adapter, consider keeping the unmerged PeftModel instead.
model = PeftModel.from_pretrained(base_model, LORA_MODEL)
model = model.merge_and_unload()
model.eval()
# ── Stop generation when JSON is complete ───────────────────
class StopOnJsonComplete(StoppingCriteria):
    """Stop generating once the completion contains one complete JSON object.

    The completion is the part of the first sequence in the batch that comes
    after the prompt. Generation stops as soon as the leading ``{ ... }``
    object is closed. Braces inside JSON string literals (for example
    ``{"note": "a } b"}``) are skipped, so they cannot end generation early
    and leave a truncated object behind.
    """

    def __init__(self, tokenizer, prompt_length):
        # tokenizer: used to turn generated token ids back into text.
        # prompt_length: number of prompt tokens at the start of input_ids[0].
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    @staticmethod
    def _object_closed(text):
        """Return True if *text* (which starts with "{") closes that object."""
        depth = 0
        in_string = False
        escaped = False
        for char in text:
            if in_string:
                # Inside a string literal, only an unescaped quote matters.
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return True  # JSON object is complete, stop!
        return False

    def __call__(self, input_ids, scores, **kwargs):
        # Only the first sequence is inspected (the app generates one at a time).
        new_tokens = input_ids[0][self.prompt_length:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # Do not stop while the model has not started emitting an object yet.
        if not text.startswith("{"):
            return False
        return self._object_closed(text)
# ── Extract first valid JSON from text ──────────────────────
def extract_json(text):
    """Find and return the first complete JSON object in *text*.

    Each ``{`` is tried in order as the start of an object and parsed with
    ``json.JSONDecoder.raw_decode``. That parser stops at the end of the
    object and ignores whatever follows. Unlike manual brace counting, it
    correctly skips braces inside string values. It is also not thrown off by
    a stray ``}`` that appears before the object.

    Args:
        text: Arbitrary text, e.g. raw model output with trailing garbage.

    Returns:
        The first parseable JSON object (a dict), or None if there is none.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # No valid object starts here; try the next opening brace.
            start = text.find("{", start + 1)
        else:
            return obj
    return None
# ── Auto-detect schema from text ────────────────────────────
def auto_schema(text):
    """Guess an extraction schema for *text* from keyword cues.

    A field is added when one of its trigger keywords appears in the
    lower-cased text. Values are type hints passed verbatim to the model
    (e.g. ``"number|null"``). A ``"note"`` field is always included.

    Keyword matching:
    * Alphabetic keywords shorter than three letters ("to", "on", "at",
      "am", "pm", "rs", "kg", "by") must occur as a whole word. This stops
      them firing inside unrelated words such as "today", "money", "that",
      "name" or "hours".
    * Letters are split from digits, so "3pm", "10am" and "5kg" still match.
    * Longer keywords and symbols ("$", "₹") match as substrings, so "mar"
      still matches "march" and "store" still matches "stores".

    Args:
        text: The user's input sentence.

    Returns:
        dict mapping field name -> type hint. If nothing besides "note" was
        detected, a generic amount/person/date/note schema is returned.
    """
    text_lower = text.lower()
    # Runs of ASCII letters; digits, spaces and punctuation act as separators.
    words = set(re.findall(r"[a-z]+", text_lower))

    def mentions(keywords):
        """Return True if any keyword is present, using the rules above."""
        for keyword in keywords:
            if keyword.isalpha() and len(keyword) < 3:
                if keyword in words:
                    return True
            elif keyword in text_lower:
                return True
        return False

    schema = {}

    money_keywords = ["paid", "sent", "received", "cost", "price", "rupees", "rs",
                      "₹", "$", "bought", "sold", "charged", "fee", "salary",
                      "budget", "owes", "owe", "lent", "borrowed", "fare", "rent"]
    # Any digit is also taken as a hint that an amount may be present.
    if mentions(money_keywords) or any(c.isdigit() for c in text):
        schema["amount"] = "number|null"

    person_keywords = ["to", "from", "with", "for", "by", "told", "asked",
                       "met", "called", "emailed", "messaged", "owes", "owe"]
    if mentions(person_keywords):
        schema["person"] = "string|null"

    date_keywords = ["jan", "feb", "mar", "apr", "may", "jun", "jul", "aug",
                     "sep", "oct", "nov", "dec", "monday", "tuesday", "wednesday",
                     "thursday", "friday", "saturday", "sunday", "today", "tomorrow",
                     "yesterday", "morning", "evening", "night", "on", "at", "pm", "am"]
    if mentions(date_keywords):
        schema["date"] = "ISO date|null"
    if mentions(["pm", "am", "morning", "evening", "night", "at"]):
        schema["time"] = "string|null"

    item_keywords = ["bought", "ordered", "purchased", "delivered", "shipped",
                     "kg", "litre", "pieces", "items", "pack", "bottle"]
    if mentions(item_keywords):
        schema["item"] = "string|null"
        schema["quantity"] = "string|null"

    location_keywords = ["store", "shop", "restaurant", "station", "airport",
                         "hotel", "office"]
    if mentions(location_keywords):
        schema["location"] = "string|null"

    travel_keywords = ["train", "flight", "bus", "booked", "ticket", "pnr",
                       "travel", "trip", "journey"]
    if mentions(travel_keywords):
        schema["from_location"] = "string|null"
        schema["to_location"] = "string|null"
        # Travel uses origin/destination instead of a single location.
        schema.pop("location", None)

    meeting_keywords = ["meeting", "call", "discuss", "review", "presentation",
                        "interview", "appointment", "schedule"]
    if mentions(meeting_keywords):
        schema["topic"] = "string|null"

    schema["note"] = "string|null"

    # Only "note" was detected: fall back to a generic schema.
    if len(schema) <= 1:
        schema = {
            "amount": "number|null",
            "person": "string|null",
            "date": "ISO date|null",
            "note": "string|null",
        }
    return schema
# ── Inference function ──────────────────────────────────────
def extract(text, custom_schema):
    """Run the fine-tuned model on *text* and return its JSON extraction.

    Args:
        text: Natural-language input from the UI.
        custom_schema: Optional JSON string describing the fields to extract.
            When empty or blank, a schema is guessed with ``auto_schema``.

    Returns:
        A ``(output, schema)`` pair of strings for the two output boxes.
        ``output`` is one of:
        * the pretty-printed extracted JSON;
        * the raw model output, if no JSON object could be parsed from it;
        * the message "Invalid JSON schema.", when *custom_schema* is not
          valid JSON (the schema string is then empty).
        ``schema`` is the pretty-printed schema that was used.
    """
    # NOTE(review): `spaces` is imported at the top of the file but never used.
    # If this Space runs on ZeroGPU hardware, this function presumably needs
    # the `@spaces.GPU` decorator — confirm the hardware tier.
    if not text or not text.strip():
        return "", ""

    if custom_schema and custom_schema.strip():
        try:
            schema = json.loads(custom_schema)
        except json.JSONDecodeError:
            return "Invalid JSON schema.", ""
    else:
        schema = auto_schema(text)

    schema_str = json.dumps(schema)
    # Presumably the prompt format the adapter was fine-tuned on — keep it exact.
    prompt = f"### Input: {text}\n### Schema: {schema_str}\n### Output:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    # Stop as soon as JSON is complete
    stop_criteria = StoppingCriteriaList([
        StopOnJsonComplete(tokenizer, prompt_length)
    ])

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            stopping_criteria=stop_criteria,
        )

    # Decode only the generated continuation, not the echoed prompt.
    new_tokens = outputs[0][prompt_length:]
    output_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Extract just the JSON, ignore any trailing garbage
    parsed = extract_json(output_part)
    schema_json = json.dumps(schema, indent=2)
    # `is not None`: an empty object `{}` is a valid extraction and must not
    # fall back to showing the raw model text.
    if parsed is not None:
        return json.dumps(parsed, indent=2, ensure_ascii=False), schema_json
    return output_part, schema_json
# ── Example inputs ──────────────────────────────────────────
# gr.Examples expects one row per example and one value per input component.
# Only the text box is wired up as an input, so each row holds a single sentence.
examples = [
    [sentence]
    for sentence in (
        "Paid 500 to Ravi for lunch on Jan 5",
        "Meeting with Sarah at 3pm tomorrow to discuss the project budget of $10,000",
        "Bought 3 kg of rice from Krishna Stores for 250 rupees on March 10",
        "Booked a train from Chennai to Bangalore on April 10 for 750 rupees",
        "Ravi owes me 300 for last week's dinner",
        "Ordered 2 pizzas and 1 coke from Dominos for 850 rupees",
    )
]
# ── Gradio UI ───────────────────────────────────────────────
# Builds the demo page and starts the server when the module runs.
with gr.Blocks(title="json-extract") as demo:
    gr.Markdown(
        """
        # json-extract
        Extract structured JSON from natural language text.
        Just type a sentence — the model auto-detects the right schema and extracts clean JSON.
        """
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="e.g. Paid 500 to Ravi for lunch on Jan 5",
                lines=3,
            )
            btn = gr.Button("Extract", variant="primary")
            with gr.Accordion("Advanced: Custom Schema (optional)", open=False):
                schema_input = gr.Textbox(
                    label="Custom JSON Schema",
                    placeholder='Leave empty for auto-detect, or enter e.g. {"amount": "number", "person": "string|null"}',
                    lines=3,
                )
        # Right column: outputs, in the same order as extract() returns them.
        with gr.Column():
            output = gr.Textbox(label="Extracted JSON", lines=10)
            detected_schema = gr.Textbox(label="Schema Used", lines=5)

    # Clicking an example fills only the text box.
    gr.Examples(
        examples=examples,
        inputs=[text_input],
    )

    # Both the button and pressing Enter in the text box run the extraction.
    btn.click(fn=extract, inputs=[text_input, schema_input], outputs=[output, detected_schema])
    text_input.submit(fn=extract, inputs=[text_input, schema_input], outputs=[output, detected_schema])

demo.launch()