AlexHung29629 committed on
Commit
6ef828c
·
verified ·
1 Parent(s): 37f620a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -1,27 +1,34 @@
1
  import torch
2
  import spaces
3
  import gradio as gr
4
- from transformers import pipeline
5
- from PIL import Image
6
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
 
7
 
8
  # Load model and processor
9
- model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-screen2words-large", dtype=torch.bfloat16).to("cuda")
 
 
10
  model.eval()
11
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-large")
12
 
13
  # Define the function
14
  @spaces.GPU
15
- def describe_ui(image):
16
- inputs = processor(images=image, text="", return_tensors="pt").to(dtype=torch.bfloat16, device="cuda")
 
 
 
17
  predictions = model.generate(**inputs)
18
  return processor.decode(predictions[0], skip_special_tokens=False)
19
 
20
  # Launch the Gradio interface
21
  gr.Interface(
22
  fn=describe_ui,
23
- inputs=gr.Image(type="pil"),
24
- outputs="text",
 
 
 
25
  title="UI Screen Describer (Pix2Struct)",
26
- description="Upload a screenshot or UI image and get an automatic description powered by Google’s Pix2Struct model."
27
  ).launch()
 
1
  import torch
2
  import spaces
3
  import gradio as gr
 
 
4
  from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
5
+ from PIL import Image
6
 
7
  # Load model and processor
8
+ model = Pix2StructForConditionalGeneration.from_pretrained(
9
+ "google/pix2struct-screen2words-large", dtype=torch.bfloat16
10
+ ).to("cuda")
11
  model.eval()
12
  processor = Pix2StructProcessor.from_pretrained("google/pix2struct-screen2words-large")
13
 
14
  # Define the function
15
  @spaces.GPU
16
+ def describe_ui(image, text):
17
+ # text is the user-supplied prompt; it may be an empty string
18
+ inputs = processor(images=image, text=text or "", return_tensors="pt").to(
19
+ dtype=torch.bfloat16, device="cuda"
20
+ )
21
  predictions = model.generate(**inputs)
22
  return processor.decode(predictions[0], skip_special_tokens=False)
23
 
24
  # Launch the Gradio interface
25
  gr.Interface(
26
  fn=describe_ui,
27
+ inputs=[
28
+ gr.Image(type="pil", label="Upload UI Screenshot"),
29
+ gr.Textbox(label="Optional prompt / instruction", placeholder="e.g. Describe layout and buttons"),
30
+ ],
31
+ outputs=gr.Textbox(label="Model Output"),
32
  title="UI Screen Describer (Pix2Struct)",
33
+ description="Upload a screenshot or UI image and optionally enter a text prompt. The model (Google Pix2Struct) will generate a detailed description.",
34
  ).launch()