Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import json
|
|
|
|
| 3 |
import spaces
|
| 4 |
import torch
|
| 5 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 6 |
from peft import PeftModel
|
| 7 |
|
| 8 |
# ββ Load model once at startup ββββββββββββββββββββββββββββββ
|
| 9 |
BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 10 |
-
LORA_MODEL = "suneeldk/json-extract"
|
| 11 |
|
| 12 |
tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL)
|
| 13 |
|
| 14 |
-
# Load in 4-bit for faster inference
|
| 15 |
bnb_config = BitsAndBytesConfig(
|
| 16 |
load_in_4bit=True,
|
| 17 |
bnb_4bit_quant_type="nf4",
|
|
@@ -25,10 +25,54 @@ base_model = AutoModelForCausalLM.from_pretrained(
|
|
| 25 |
)
|
| 26 |
|
| 27 |
model = PeftModel.from_pretrained(base_model, LORA_MODEL)
|
| 28 |
-
model = model.merge_and_unload()
|
| 29 |
model.eval()
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
# ββ Auto-detect schema from text ββββββββββββββββββββββββββββ
|
| 33 |
def auto_schema(text):
|
| 34 |
text_lower = text.lower()
|
|
@@ -60,8 +104,8 @@ def auto_schema(text):
|
|
| 60 |
schema["item"] = "string|null"
|
| 61 |
schema["quantity"] = "string|null"
|
| 62 |
|
| 63 |
-
location_keywords = ["
|
| 64 |
-
"
|
| 65 |
if any(k in text_lower for k in location_keywords):
|
| 66 |
schema["location"] = "string|null"
|
| 67 |
|
|
@@ -107,23 +151,30 @@ def extract(text, custom_schema):
|
|
| 107 |
schema_str = json.dumps(schema)
|
| 108 |
prompt = f"### Input: {text}\n### Schema: {schema_str}\n### Output:"
|
| 109 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
with torch.no_grad():
|
| 112 |
outputs = model.generate(
|
| 113 |
**inputs,
|
| 114 |
-
max_new_tokens=128,
|
| 115 |
-
do_sample=False,
|
| 116 |
pad_token_id=tokenizer.eos_token_id,
|
|
|
|
| 117 |
)
|
| 118 |
|
| 119 |
-
|
| 120 |
-
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
| 121 |
output_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
|
|
|
| 125 |
return json.dumps(parsed, indent=2, ensure_ascii=False), json.dumps(schema, indent=2)
|
| 126 |
-
|
| 127 |
return output_part, json.dumps(schema, indent=2)
|
| 128 |
|
| 129 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import json
|
| 3 |
+
import re
|
| 4 |
import spaces
|
| 5 |
import torch
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
|
| 7 |
from peft import PeftModel
|
| 8 |
|
| 9 |
# ββ Load model once at startup ββββββββββββββββββββββββββββββ
|
| 10 |
BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 11 |
+
LORA_MODEL = "suneeldk/json-extract" # β change this
|
| 12 |
|
| 13 |
tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL)
|
| 14 |
|
|
|
|
| 15 |
bnb_config = BitsAndBytesConfig(
|
| 16 |
load_in_4bit=True,
|
| 17 |
bnb_4bit_quant_type="nf4",
|
|
|
|
| 25 |
)
|
| 26 |
|
| 27 |
model = PeftModel.from_pretrained(base_model, LORA_MODEL)
|
| 28 |
+
model = model.merge_and_unload()
|
| 29 |
model.eval()
|
| 30 |
|
| 31 |
|
| 32 |
+
# ── Stop generation when JSON is complete ───────────────────
class StopOnJsonComplete(StoppingCriteria):
    """Stop generating once the decoded continuation holds a complete JSON object.

    The continuation (everything after the prompt tokens) is re-decoded on every
    generation step and scanned for balanced braces.  Unlike a naive brace
    counter, this scan is string-aware: `{` / `}` characters that appear inside
    JSON string values (including escaped quotes) do not affect the depth, so
    an output such as ``{"note": "}"}`` no longer triggers a premature stop.
    """

    def __init__(self, tokenizer, prompt_length):
        # Tokenizer used to decode the generated ids back to text.
        self.tokenizer = tokenizer
        # Number of prompt tokens to skip before scanning the continuation.
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores, **kwargs):
        new_tokens = input_ids[0][self.prompt_length:]
        text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # Until the model has opened an object there is nothing to balance.
        if not text.startswith("{"):
            return False
        depth = 0
        in_string = False
        escaped = False
        for char in text:
            if in_string:
                # Inside a JSON string: only an unescaped '"' ends it.
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return True  # JSON object is complete, stop!
        return False
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ── Extract first valid JSON from text ──────────────────────
def extract_json(text):
    """Return the first complete JSON object parsed from *text*, or None.

    The original hand-rolled brace counter miscounted `{` / `}` characters
    that occur inside JSON string values (e.g. ``{"a": "}"}``) and, after a
    failed parse, reset ``start`` without resetting ``depth`` — so valid
    objects containing braces in strings were never found.  Delegating to
    ``json.JSONDecoder.raw_decode`` at each candidate ``{`` position handles
    nesting, strings, and escapes correctly while ignoring any trailing
    garbage after the object.

    Args:
        text: Arbitrary model output that may embed a JSON object.

    Returns:
        The parsed object (a dict) for the first position where a complete
        JSON value starting with ``{`` can be decoded, or ``None`` if no
        such object exists.
    """
    decoder = json.JSONDecoder()
    idx = text.find("{")
    while idx != -1:
        try:
            obj, _end = decoder.raw_decode(text, idx)
            return obj
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next one.
            idx = text.find("{", idx + 1)
    return None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
# ββ Auto-detect schema from text ββββββββββββββββββββββββββββ
|
| 77 |
def auto_schema(text):
|
| 78 |
text_lower = text.lower()
|
|
|
|
| 104 |
schema["item"] = "string|null"
|
| 105 |
schema["quantity"] = "string|null"
|
| 106 |
|
| 107 |
+
location_keywords = ["store", "shop", "restaurant", "station", "airport",
|
| 108 |
+
"hotel", "office"]
|
| 109 |
if any(k in text_lower for k in location_keywords):
|
| 110 |
schema["location"] = "string|null"
|
| 111 |
|
|
|
|
| 151 |
schema_str = json.dumps(schema)
|
| 152 |
prompt = f"### Input: {text}\n### Schema: {schema_str}\n### Output:"
|
| 153 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 154 |
+
prompt_length = inputs["input_ids"].shape[1]
|
| 155 |
+
|
| 156 |
+
# Stop as soon as JSON is complete
|
| 157 |
+
stop_criteria = StoppingCriteriaList([
|
| 158 |
+
StopOnJsonComplete(tokenizer, prompt_length)
|
| 159 |
+
])
|
| 160 |
|
| 161 |
with torch.no_grad():
|
| 162 |
outputs = model.generate(
|
| 163 |
**inputs,
|
| 164 |
+
max_new_tokens=128,
|
| 165 |
+
do_sample=False,
|
| 166 |
pad_token_id=tokenizer.eos_token_id,
|
| 167 |
+
stopping_criteria=stop_criteria,
|
| 168 |
)
|
| 169 |
|
| 170 |
+
new_tokens = outputs[0][prompt_length:]
|
|
|
|
| 171 |
output_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 172 |
|
| 173 |
+
# Extract just the JSON, ignore any trailing garbage
|
| 174 |
+
parsed = extract_json(output_part)
|
| 175 |
+
if parsed:
|
| 176 |
return json.dumps(parsed, indent=2, ensure_ascii=False), json.dumps(schema, indent=2)
|
| 177 |
+
else:
|
| 178 |
return output_part, json.dumps(schema, indent=2)
|
| 179 |
|
| 180 |
|