suneeldk committed on
Commit
0d6ff7a
·
verified ·
1 Parent(s): be027a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -13
app.py CHANGED
@@ -1,17 +1,17 @@
1
  import gradio as gr
2
  import json
 
3
  import spaces
4
  import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6
  from peft import PeftModel
7
 
8
  # ── Load model once at startup ──────────────────────────────
9
  BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
10
- LORA_MODEL = "suneeldk/json-extract"
11
 
12
  tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL)
13
 
14
- # Load in 4-bit for faster inference
15
  bnb_config = BitsAndBytesConfig(
16
  load_in_4bit=True,
17
  bnb_4bit_quant_type="nf4",
@@ -25,10 +25,54 @@ base_model = AutoModelForCausalLM.from_pretrained(
25
  )
26
 
27
  model = PeftModel.from_pretrained(base_model, LORA_MODEL)
28
- model = model.merge_and_unload() # Merge LoRA into base β€” removes adapter overhead
29
  model.eval()
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # ── Auto-detect schema from text ────────────────────────────
33
  def auto_schema(text):
34
  text_lower = text.lower()
@@ -60,8 +104,8 @@ def auto_schema(text):
60
  schema["item"] = "string|null"
61
  schema["quantity"] = "string|null"
62
 
63
- location_keywords = ["from", "to", "at", "in", "store", "shop", "restaurant",
64
- "station", "airport", "hotel", "office", "train", "flight", "bus"]
65
  if any(k in text_lower for k in location_keywords):
66
  schema["location"] = "string|null"
67
 
@@ -107,23 +151,30 @@ def extract(text, custom_schema):
107
  schema_str = json.dumps(schema)
108
  prompt = f"### Input: {text}\n### Schema: {schema_str}\n### Output:"
109
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
110
 
111
  with torch.no_grad():
112
  outputs = model.generate(
113
  **inputs,
114
- max_new_tokens=128, # JSON output is short, no need for 512
115
- do_sample=False, # Greedy decoding β€” faster than sampling
116
  pad_token_id=tokenizer.eos_token_id,
 
117
  )
118
 
119
- # Decode only the new tokens, skip the prompt
120
- new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
121
  output_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
122
 
123
- try:
124
- parsed = json.loads(output_part)
 
125
  return json.dumps(parsed, indent=2, ensure_ascii=False), json.dumps(schema, indent=2)
126
- except json.JSONDecodeError:
127
  return output_part, json.dumps(schema, indent=2)
128
 
129
 
 
1
  import gradio as gr
2
  import json
3
+ import re
4
  import spaces
5
  import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
7
  from peft import PeftModel
8
 
9
  # ── Load model once at startup ──────────────────────────────
10
  BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
11
+ LORA_MODEL = "suneeldk/json-extract" # ← change this
12
 
13
  tokenizer = AutoTokenizer.from_pretrained(LORA_MODEL)
14
 
 
15
  bnb_config = BitsAndBytesConfig(
16
  load_in_4bit=True,
17
  bnb_4bit_quant_type="nf4",
 
25
  )
26
 
27
  model = PeftModel.from_pretrained(base_model, LORA_MODEL)
28
+ model = model.merge_and_unload()
29
  model.eval()
30
 
31
 
32
# ── Stop generation when JSON is complete ───────────────────
class StopOnJsonComplete(StoppingCriteria):
    """Halt generation as soon as the newly generated text is a balanced JSON object.

    Decodes only the tokens produced after the prompt on every step and
    returns True once the brace depth of a text that starts with '{'
    returns to zero.
    """

    def __init__(self, tokenizer, prompt_length):
        # prompt_length: token count of the prompt, used to slice off the
        # prompt portion of input_ids before decoding.
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(self, input_ids, scores, **kwargs):
        generated = input_ids[0][self.prompt_length:]
        decoded = self.tokenizer.decode(generated, skip_special_tokens=True).strip()
        # Only consider stopping once the output has begun a JSON object.
        if not decoded.startswith("{"):
            return False
        depth = 0
        for ch in decoded:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    # Braces balanced — the JSON object is complete.
                    return True
        return False
54
+
55
+
56
# ── Extract first valid JSON from text ──────────────────────
def extract_json(text):
    """Return the first complete JSON object embedded in *text*, or None.

    Tries ``json.JSONDecoder.raw_decode`` at every ``{`` in the input and
    returns the first successfully parsed object. Unlike naive brace
    counting, this correctly handles braces inside JSON string values
    (e.g. ``{"s": "}"}``) and is not thrown off by stray unbalanced
    braces in the surrounding text (e.g. ``}{"a": 1}``).

    Parameters:
        text: arbitrary model output that may contain a JSON object.

    Returns:
        The parsed Python object (a dict, since parsing starts at '{'),
        or None when no valid JSON object is found.
    """
    decoder = json.JSONDecoder()
    for i, char in enumerate(text):
        if char != "{":
            continue
        try:
            # raw_decode parses one JSON value starting at index i and
            # ignores any trailing garbage after it.
            obj, _end = decoder.raw_decode(text, i)
        except json.JSONDecodeError:
            # Not a valid object at this position — keep scanning.
            continue
        return obj
    return None
74
+
75
+
76
  # ── Auto-detect schema from text ────────────────────────────
77
  def auto_schema(text):
78
  text_lower = text.lower()
 
104
  schema["item"] = "string|null"
105
  schema["quantity"] = "string|null"
106
 
107
+ location_keywords = ["store", "shop", "restaurant", "station", "airport",
108
+ "hotel", "office"]
109
  if any(k in text_lower for k in location_keywords):
110
  schema["location"] = "string|null"
111
 
 
151
  schema_str = json.dumps(schema)
152
  prompt = f"### Input: {text}\n### Schema: {schema_str}\n### Output:"
153
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
154
+ prompt_length = inputs["input_ids"].shape[1]
155
+
156
+ # Stop as soon as JSON is complete
157
+ stop_criteria = StoppingCriteriaList([
158
+ StopOnJsonComplete(tokenizer, prompt_length)
159
+ ])
160
 
161
  with torch.no_grad():
162
  outputs = model.generate(
163
  **inputs,
164
+ max_new_tokens=128,
165
+ do_sample=False,
166
  pad_token_id=tokenizer.eos_token_id,
167
+ stopping_criteria=stop_criteria,
168
  )
169
 
170
+ new_tokens = outputs[0][prompt_length:]
 
171
  output_part = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
172
 
173
+ # Extract just the JSON, ignore any trailing garbage
174
+ parsed = extract_json(output_part)
175
+ if parsed:
176
  return json.dumps(parsed, indent=2, ensure_ascii=False), json.dumps(schema, indent=2)
177
+ else:
178
  return output_part, json.dumps(schema, indent=2)
179
 
180