|
|
import torch |
|
|
from transformers import pipeline |
|
|
import gradio as gr |
|
|
import json |
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub checkpoint that the text-generation pipeline loads; it is
# prompted to turn a free-text address into structured JSON.
model_id: str = "dylanhogg/gnaf-structured-address-v0.2-712c28b-20251019-170453"

# Generation settings forwarded to every pipeline call.
max_new_tokens: int = 256  # cap on tokens generated per request
do_sample: bool = False  # sampling disabled: decode deterministically
|
|
|
|
|
def _pick_device() -> "tuple[str, torch.dtype, str | None]":
    """Choose the compute device, tensor dtype and ``device_map`` for the model.

    Preference order is CUDA, then Apple MPS, then CPU. On CUDA the GPU name is
    printed, bfloat16 is used when the card supports it (float16 otherwise) and
    placement is delegated to ``device_map="auto"``. The other backends return
    ``None`` as the device map.
    """
    if torch.cuda.is_available():
        half_precision = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
        print(f"CUDA device: {gpu_name}")
        return "cuda", half_precision, "auto"
    if torch.backends.mps.is_available():
        return "mps", torch.float16, None
    # No accelerator found: full precision on the CPU.
    return "cpu", torch.float32, None


device, dtype, device_map = _pick_device()

print(f"Device settings: {device=}, {dtype=}, {device_map=}")
|
|
|
|
|
# Build the text-generation pipeline once at import time so every request
# reuses the loaded weights.
#
# With device_map="auto" (the CUDA branch) placement is delegated to the
# device map. When device_map is None (MPS / CPU branches) the chosen device is
# passed explicitly. Otherwise the pipeline picks a device on its own, which
# (depending on the transformers version) may not match the `device` reported
# in the UI timing label.
pipe = pipeline(
    "text-generation",
    model=model_id,
    dtype=dtype,
    device_map=device_map,
    # Passing both `device` and `device_map` is discouraged by transformers,
    # so only one of them is ever set.
    device=device if device_map is None else None,
)

print(f"Model {pipe.model=}")
print(f"Model config {pipe.model.config=}")
print("Ready to parse addresses!")
|
|
|
|
|
|
|
|
def _extract_json(text: str) -> str:
    """Pretty-print the JSON object embedded in *text*.

    The span from the first ``{`` to the last ``}`` is parsed. If no such span
    exists, a "No JSON found" message is returned. If the span is not valid
    JSON, the error is returned together with the raw text.
    """
    start = text.find("{")
    end = text.rfind("}") + 1
    if start == -1 or end <= start:
        return "No JSON found in response"
    try:
        parsed = json.loads(text[start:end])
    except json.JSONDecodeError as err:
        return f"Error parsing JSON: {err}\n\nRaw output:\n{text}"
    # ensure_ascii=False keeps accented characters (e.g. "café") readable
    # in the UI instead of escaping them to \uXXXX sequences.
    return json.dumps(parsed, indent=2, ensure_ascii=False)


def parse_address(user_address: str) -> tuple[str, str, str]:
    """Run the address model on one input string.

    Args:
        user_address: Free-text address (possibly embedded in a sentence).

    Returns:
        A ``(formatted_json, raw_output, timing_info)`` tuple:
        the pretty-printed JSON (or an explanatory message), the model's raw
        reply text, and a human-readable inference time with the device used.
    """
    # Don't spend a model call on an empty textbox.
    if not user_address or not user_address.strip():
        return "No address provided", "", ""

    user_content = f"Translate a text address into structured json.\n{user_address}"
    messages = [
        {"role": "system", "content": "Translate a text address into structured json."},
        {"role": "user", "content": user_content},
    ]

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted.
    inference_start = time.perf_counter()
    outputs = pipe(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
    )
    inference_time = time.perf_counter() - inference_start

    # For chat-message input the pipeline's "generated_text" is the message
    # list; the last entry holds the model's reply.
    last_content = outputs[0]["generated_text"][-1]["content"]

    timing_info = f"{inference_time:.3f}s (on {device} device)"
    return _extract_json(last_content), last_content, timing_info
|
|
|
|
|
|
|
|
# Gradio UI. The left column holds the address textbox and a button; the right
# column shows inference time, pretty JSON and the raw model text. Both the
# button and pressing Enter in the textbox call parse_address.
with gr.Blocks(title="Address Parser") as demo:
    gr.Markdown("# 🏠 Structured Address Parser")
    gr.Markdown("Extracts structured address information from text.")
    gr.Markdown(
        "Structured output is JSON format with Australian [G-NAF](https://docs.geoscape.com.au/projects/gnaf_desc/en/stable/overview.html) (Geocoded National Address File) fields."
    )
    gr.Markdown(f"Model: [{model_id}](https://huggingface.co/{model_id})")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="📝 Input Address",
                placeholder="Enter an address...",
                value="48a Pirrama Rd Pyrmont NSW 2009",
            )
            submit_btn = gr.Button("Parse Address", variant="primary")

        with gr.Column():
            took_output = gr.Textbox(label="⏱️ Inference Time", interactive=False, max_lines=1)
            json_output = gr.Textbox(label="📋 Structured JSON", interactive=False, lines=10)
            raw_output = gr.Textbox(label="📄 Raw Model Output", interactive=False, lines=3)

    # Mix of clean addresses, abbreviations/typos and addresses embedded in
    # conversational sentences.
    gr.Examples(
        examples=[
            "48a Pirrama Rd Pyrmont NSW 2009",
            "Floor 3, 152-156 Clarence St, Sydney NSW 2000",
            "Aptt 16, 400 Bondi Rd, Bondi NSW 2026",
            "Unit 18/14-18 Flood St, Bondi, NSW 2026",
            "Lvl 15/333 George St Sydney NSW 2000",
            "Check out: 44 Ulm St, Maroubra NSW 2035",
            "44 Ulm St, Maroubra NSW 2035 is where it's at!",
            "Have you been to level 3 123 George St, Sydney NSW 2000?",
            "Suite 4, 27-31 King William St, Adelaide SA 5000",
            "Unit 2B 15 O'Connell St, Parramatta NSW 2150",
            "Lvl 9, 100 Creek St, Brisbane QLD 4000",
            "Shop 5/42 Jetty Rd, Glenelg SA 5045",
            "Lot 12, 89 Wattle Grove Rd, Mulgoa NSW 2745",
            "Meet me at 20 Smith St, Collingwood VIC 3066 later today!",
            "We just moved into 7/145 Marine Parade, Coolangatta QLD 4225.",
            "Drop the package at level 2, 88 Pitt St Sydney NSW 2000 please.",
            "I think the café is near 11-13 Lygon St, Carlton VIC 3053.",
            "Could you send the invoice to our office at 120 Northbourne Ave Canberra ACT 2601?",
            "PO Box 42, Dubbo NSW 2830",
            "Head to 9/250 St Kilda Rd, Melbourne VIC 3006 for the meeting.",
            "Unit 5B, 18-22 Manning St, South Brisbane QLD 4101",
            "Corner of King St and Brown St, Newtown NSW 2042",
            "Flat 3, 77 Princes Hwy, Dandenong VIC 3175",
            "We're staying at 22 Beach Rd, Batemans Bay NSW 2536 this weekend.",
            "Drop it off at 6/89 Murray St, Perth WA 6000.",
            "Warehouse 12, 45 Industrial Dr, Mayfield NSW 2304",
            "Visit us on level 4, 11 Elizabeth St Hobart TAS 7000.",
            "The Airbnb is at 3 Hilltop Ave, Burleigh Heads QLD 4220.",
        ],
        inputs=input_text,
    )

    # Order must match parse_address's return tuple: (json, raw, timing).
    parse_outputs = [json_output, raw_output, took_output]
    submit_btn.click(fn=parse_address, inputs=input_text, outputs=parse_outputs)
    input_text.submit(fn=parse_address, inputs=input_text, outputs=parse_outputs)
|
|
|
|
|
def _launch_app() -> None:
    """Start the Gradio server for ``demo``, printing status before and after."""
    print("Launching app...")
    demo.launch()
    print("Done.")


if __name__ == "__main__":
    _launch_app()
|
|
|