dylanhogg committed on
Commit
020f360
·
1 Parent(s): 462b9e0

Handle device and dtype better

Browse files
Files changed (2) hide show
  1. Makefile +1 -1
  2. app.py +27 -9
Makefile CHANGED
@@ -10,7 +10,7 @@ clean:
10
  rm -rf .venv
11
 
12
  run:
13
- source .venv/bin/activate ; python address_parser.py
14
 
15
  black-check:
16
  source .venv/bin/activate ; black . --check --verbose --line-length 120
 
10
  rm -rf .venv
11
 
12
  run:
13
+ source .venv/bin/activate ; python app.py
14
 
15
  black-check:
16
  source .venv/bin/activate ; black . --check --verbose --line-length 120
app.py CHANGED
@@ -3,16 +3,36 @@ from transformers import pipeline
3
  import gradio as gr
4
  import json
5
 
6
- # Initialize model pipeline
7
  model_id = "dylanhogg/gnaf-structured-address-v0.1-75a1791-20250921-063650"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  pipe = pipeline(
10
  "text-generation",
11
  model=model_id,
12
- dtype=torch.bfloat16,
13
- device_map="auto",
14
  )
15
 
 
 
 
16
 
17
  def parse_address(user_address: str) -> tuple[str, str]:
18
  """Parse address and return both raw response and JSON"""
@@ -24,8 +44,8 @@ def parse_address(user_address: str) -> tuple[str, str]:
24
 
25
  outputs = pipe(
26
  messages,
27
- max_new_tokens=256,
28
- do_sample=False,
29
  )
30
 
31
  response = outputs[0]["generated_text"]
@@ -48,7 +68,6 @@ def parse_address(user_address: str) -> tuple[str, str]:
48
  return formatted_json, last_content
49
 
50
 
51
- # Create Gradio interface
52
  with gr.Blocks(title="Address Parser") as demo:
53
  gr.Markdown("# 🏠 Structured Address Parser")
54
  gr.Markdown("This model converts text addresses into structured JSON format.")
@@ -64,7 +83,6 @@ with gr.Blocks(title="Address Parser") as demo:
64
  json_output = gr.Textbox(label="Structured JSON", interactive=False, lines=10)
65
  raw_output = gr.Textbox(label="Raw Model Output", interactive=False, lines=5)
66
 
67
- # Examples
68
  gr.Examples(
69
  examples=[
70
  "48a Pirrama Rd Pyrmont NSW 2009",
@@ -99,10 +117,10 @@ with gr.Blocks(title="Address Parser") as demo:
99
  inputs=input_text,
100
  )
101
 
102
- # Handle events
103
  submit_btn.click(fn=parse_address, inputs=input_text, outputs=[json_output, raw_output])
104
  input_text.submit(fn=parse_address, inputs=input_text, outputs=[json_output, raw_output])
105
 
106
- # Launch the app
107
  if __name__ == "__main__":
 
108
  demo.launch()
 
 
3
  import gradio as gr
4
  import json
5
 
 
6
  model_id = "dylanhogg/gnaf-structured-address-v0.1-75a1791-20250921-063650"
7
+ max_new_tokens = 256
8
+ do_sample = False
9
+
10
+ if torch.cuda.is_available():
11
+ device = "cuda"
12
+ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
13
+ device_map = "auto"
14
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
15
+ elif torch.backends.mps.is_available():
16
+ device = "mps"
17
+ dtype = torch.float16
18
+ device_map = None
19
+ else:
20
+ device = "cpu"
21
+ dtype = torch.float32
22
+ device_map = None
23
+
24
+ print(f"Device settings: {device=}, {dtype=}, {device_map=}")
25
 
26
  pipe = pipeline(
27
  "text-generation",
28
  model=model_id,
29
+ dtype=dtype,
30
+ device_map=device_map,
31
  )
32
 
33
+ print(f"Model {pipe.model=}")
34
+ print(f"Model config {pipe.model.config=}")
35
+
36
 
37
  def parse_address(user_address: str) -> tuple[str, str]:
38
  """Parse address and return both raw response and JSON"""
 
44
 
45
  outputs = pipe(
46
  messages,
47
+ max_new_tokens=max_new_tokens,
48
+ do_sample=do_sample,
49
  )
50
 
51
  response = outputs[0]["generated_text"]
 
68
  return formatted_json, last_content
69
 
70
 
 
71
  with gr.Blocks(title="Address Parser") as demo:
72
  gr.Markdown("# 🏠 Structured Address Parser")
73
  gr.Markdown("This model converts text addresses into structured JSON format.")
 
83
  json_output = gr.Textbox(label="Structured JSON", interactive=False, lines=10)
84
  raw_output = gr.Textbox(label="Raw Model Output", interactive=False, lines=5)
85
 
 
86
  gr.Examples(
87
  examples=[
88
  "48a Pirrama Rd Pyrmont NSW 2009",
 
117
  inputs=input_text,
118
  )
119
 
 
120
  submit_btn.click(fn=parse_address, inputs=input_text, outputs=[json_output, raw_output])
121
  input_text.submit(fn=parse_address, inputs=input_text, outputs=[json_output, raw_output])
122
 
 
123
  if __name__ == "__main__":
124
+ print("Launching app...")
125
  demo.launch()
126
+ print("Done.")