Fred808 committed on
Commit
c6b9676
·
verified ·
1 Parent(s): 27a3bef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import subprocess

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
 
+
18
+ # Load the large model and processor
19
+ try:
20
+ vision_language_model_large = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to(device).eval()
21
+ vision_language_processor_large = AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
22
+ except Exception as e:
23
+ print(f"Error loading large model: {e}")
24
+ vision_language_model_large = None
25
+ vision_language_processor_large = None
26
+
27
+ def describe_image(uploaded_image, model_choice):
28
+ """
29
+ Generates a detailed description of the input image using the selected model.
30
+ Args:
31
+ uploaded_image (PIL.Image.Image): The image to describe.
32
+ model_choice (str): The model to use, either "Base" or "Large".
33
+ Returns:
34
+ str: A detailed textual description of the image or an error message.
35
+ """
36
+ if uploaded_image is None:
37
+ return "Please upload an image."
38
+
39
+ if model_choice == "Base":
40
+ if vision_language_model_base is None:
41
+ return "Base model failed to load."
42
+ model = vision_language_model_base
43
+ processor = vision_language_processor_base
44
+ elif model_choice == "Large":
45
+ if vision_language_model_large is None:
46
+ return "Large model failed to load."
47
+ model = vision_language_model_large
48
+ processor = vision_language_processor_large
49
+ else:
50
+ return "Invalid model choice."
51
+
52
+ if not isinstance(uploaded_image, Image.Image):
53
+ uploaded_image = Image.fromarray(uploaded_image)
54
+
55
+ inputs = processor(text="<MORE_DETAILED_CAPTION>", images=uploaded_image, return_tensors="pt").to(device)
56
+ with torch.no_grad():
57
+ generated_ids = model.generate(
58
+ input_ids=inputs["input_ids"],
59
+ pixel_values=inputs["pixel_values"],
60
+ max_new_tokens=1024,
61
+ early_stopping=False,
62
+ do_sample=False,
63
+ num_beams=3,
64
+ )
65
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
66
+ processed_description = processor.post_process_generation(
67
+ generated_text,
68
+ task="<MORE_DETAILED_CAPTION>",
69
+ image_size=(uploaded_image.width, uploaded_image.height)
70
+ )
71
+ image_description = processed_description["<MORE_DETAILED_CAPTION>"]
72
+ print("\nImage description generated!:", image_description)
73
+ return image_description
74
+
75
+ # Description for the interface
76
+ description = "Select the model to use for generating the image description. 'Base' is smaller and faster, while 'Large' is more accurate but slower."
77
+ if device == "cpu":
78
+ description += " Note: Running on CPU, which may be slow for large models."
79
+
80
+ # Create the Gradio interface
81
+ image_description_interface = gr.Interface(
82
+ fn=describe_image,
83
+ inputs=[
84
+ gr.Image(label="Upload Image", type="pil"),
85
+ gr.Radio(["Base", "Large"], label="Model Choice", value="Base")
86
+ ],
87
+ outputs=gr.Textbox(label="Generated Caption", lines=4, show_copy_button=True),
88
+ live=False,
89
+ title="Florence-2 Models Image Captions",
90
+ description=description
91
+ )
92
+
93
+ # Launch the interface
94
+ image_description_interface.launch(debug=True, ssr_mode=False)