EYEDOL committed
Commit 3575d8b · verified · 1 Parent(s): 07bac4e

Update app.py

Files changed (1):
  app.py  +64 -58
app.py CHANGED
@@ -1,47 +1,31 @@
 import gradio as gr
 import torch
-from transformers import LlavaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
-from PIL import Image
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
 
 # Configuration
-MODEL_ID = "llava-hf/llava-1.5-7b-hf"
+MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
 
-print(f"Loading {MODEL_ID}... This may take a few minutes depending on your internet connection.")
-
-# 1. Load Model with Quantization (to save GPU memory)
-# We use 4-bit quantization so this can run on consumer GPUs (approx 6-8GB VRAM required)
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.float16
-)
+print(f"Loading {MODEL_ID}...")
 
+# 1. Load Model
+# We use bfloat16 (half precision) which is faster than 4-bit for small models
+# and fits easily in 16GB or even 8GB VRAM.
 try:
-    processor = AutoProcessor.from_pretrained(MODEL_ID)
-    model = LlavaForConditionalGeneration.from_pretrained(
-        MODEL_ID,
-        quantization_config=quantization_config,
-        device_map="auto"
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
     )
+
+    # The min_pixels and max_pixels arguments help control resolution for speed
+    processor = AutoProcessor.from_pretrained(MODEL_ID, min_pixels=256*28*28, max_pixels=1280*28*28)
     print("Model loaded successfully!")
 except Exception as e:
     print(f"Error loading model: {e}")
-    print("Ensure you have a GPU available and 'bitsandbytes' installed.")
+    print("Ensure you have a GPU available.")
     exit()
 
-def format_prompt(image, history, message):
-    """
-    Formats the conversation history and new message into the template LLaVA expects.
-    Standard LLaVA 1.5 format: USER: <image>\n<prompt>\nASSISTANT:
-    """
-    prompt = ""
-
-    # Use the conversation history to build context (simplified for single-turn image focus)
-    # Note: Multi-turn chat with LLaVA can get heavy on context length,
-    # so we focus primarily on the current question + image.
-
-    prompt = f"USER: <image>\n{message}\nASSISTANT:"
-    return prompt
-
 def chat_response(message, history, image_input):
     """
     Main generation function called by Gradio.
@@ -49,42 +33,66 @@ def chat_response(message, history, image_input):
     if image_input is None:
         return "Please upload an image first to chat about it!"
 
-    # 1. Prepare text prompt
-    prompt_text = format_prompt(image_input, history, message)
+    # 2. Prepare the messages for Qwen2-VL
+    # Qwen expects a specific format: a list of messages with specific 'type' keys
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": image_input,  # Pass the PIL image directly
+                },
+                {"type": "text", "text": message},
+            ],
+        }
+    ]
 
-    # 2. Process inputs (Image + Text)
-    # Converting image to RGB is important as some PNGs have alpha channels
-    try:
-        image = image_input.convert("RGB")
-    except Exception:
-        return "Error processing image. Please ensure it is a valid image file."
+    # 3. Process inputs
+    # qwen_vl_utils helps process the image and text into tensors
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+
+    image_inputs, video_inputs = process_vision_info(messages)
+
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
 
-    inputs = processor(text=prompt_text, images=image, return_tensors="pt").to(model.device)
+    # Move inputs to the same device as the model
+    inputs = inputs.to(model.device)
 
-    # 3. Generate Response
-    # max_new_tokens determines how long the answer can be
-    output = model.generate(
-        **inputs,
+    # 4. Generate Response
+    # We limit max_new_tokens to 200 for speed
+    generated_ids = model.generate(
+        **inputs,
         max_new_tokens=200,
         do_sample=True,
         temperature=0.7,
         top_p=0.9
     )
 
-    # 4. Decode output
-    decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0]
-
-    # The raw output contains the prompt, so we strip it out to get just the assistant's reply
-    # The prompt format is "USER: ... ASSISTANT:", so we split by ASSISTANT:
-    response = decoded_output.split("ASSISTANT:")[-1].strip()
+    # 5. Decode output
+    # We trim the input tokens from the output to get only the new response
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
 
+    response = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )[0]
+
     return response
 
 # --- Gradio UI Setup ---
-
-with gr.Blocks(title="LLaVA Image Chat", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🌋 LLaVA: Chat with Images")
-    gr.Markdown("Upload an image and ask questions about it using the LLaVA 1.5 Model.")
+with gr.Blocks(title="Qwen2-VL Chat", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🚀 Qwen2-VL-2B: Fast Image Chat")
+    gr.Markdown("Upload an image and ask questions. This 2B model is significantly faster than LLaVA-7B.")
 
     with gr.Row():
         with gr.Column(scale=1):
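
The decode step changes more than cosmetically here: the old LLaVA code split the decoded string on "ASSISTANT:", which breaks if the reply itself ever contains that delimiter, while the new code trims at the token level, slicing off exactly the tokens that were fed into generate(). A minimal illustration of that slice with made-up token ids (the numbers are hypothetical; the logic mirrors the list comprehension in the hunk above):

    # generate() returns the prompt tokens followed by the new tokens,
    # so slicing by the input length keeps only the reply.
    input_ids = [[101, 102, 103]]            # tokens fed to generate()
    generated = [[101, 102, 103, 7, 8, 9]]   # tokens generate() returned
    trimmed = [out[len(inp):] for inp, out in zip(input_ids, generated)]
    assert trimmed == [[7, 8, 9]]
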
@@ -96,14 +104,12 @@ with gr.Blocks(title="LLaVA Image Chat", theme=gr.themes.Soft()) as demo:
         additional_inputs=[image_box],
         title="Chat",
         description="Ask about the uploaded image.",
-        # Examples must match the inputs: [text_message, image_input_value]
         examples=[
             ["What is in this image?", None],
-            ["Describe the colors.", None],
-            ["Can you read the text in the image?", None],
+            ["Describe the lighting.", None],
+            ["Read the text in the image.", None],
         ],
     )
 
 if __name__ == "__main__":
-    # queue() is required for generator/streaming interactions in some environments
     demo.queue().launch()
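
To sanity-check the new inference path without launching the Gradio app, the diff above can be exercised end-to-end with a short script. This is a sketch assuming a recent transformers, the qwen-vl-utils package, and accelerate (for device_map="auto") are installed; "test.jpg" is a placeholder image path, not a file from this repo.

    import torch
    from PIL import Image
    from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
    from qwen_vl_utils import process_vision_info

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    # Same message structure as chat_response() in app.py
    messages = [{"role": "user", "content": [
        {"type": "image", "image": Image.open("test.jpg")},  # placeholder path
        {"type": "text", "text": "What is in this image?"},
    ]}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs,
                       padding=True, return_tensors="pt").to(model.device)

    output_ids = model.generate(**inputs, max_new_tokens=200)
    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, output_ids)]
    print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])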