EYEDOL commited on
Commit
019b165
·
verified ·
1 Parent(s): 2e2f9b7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import LlavaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
4
+ from PIL import Image
5
+
6
+ # Configuration
7
+ MODEL_ID = "llava-hf/llava-1.5-7b-hf"
8
+
9
+ print(f"Loading {MODEL_ID}... This may take a few minutes depending on your internet connection.")
10
+
11
+ # 1. Load Model with Quantization (to save GPU memory)
12
+ # We use 4-bit quantization so this can run on consumer GPUs (approx 6-8GB VRAM required)
13
+ quantization_config = BitsAndBytesConfig(
14
+ load_in_4bit=True,
15
+ bnb_4bit_compute_dtype=torch.float16
16
+ )
17
+
18
+ try:
19
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
20
+ model = LlavaForConditionalGeneration.from_pretrained(
21
+ MODEL_ID,
22
+ quantization_config=quantization_config,
23
+ device_map="auto"
24
+ )
25
+ print("Model loaded successfully!")
26
+ except Exception as e:
27
+ print(f"Error loading model: {e}")
28
+ print("Ensure you have a GPU available and 'bitsandbytes' installed.")
29
+ exit()
30
+
31
+ def format_prompt(image, history, message):
32
+ """
33
+ Formats the conversation history and new message into the template LLaVA expects.
34
+ Standard LLaVA 1.5 format: USER: <image>\n<prompt>\nASSISTANT:
35
+ """
36
+ prompt = ""
37
+
38
+ # Use the conversation history to build context (simplified for single-turn image focus)
39
+ # Note: Multi-turn chat with LLaVA can get heavy on context length,
40
+ # so we focus primarily on the current question + image.
41
+
42
+ prompt = f"USER: <image>\n{message}\nASSISTANT:"
43
+ return prompt
44
+
45
+ def chat_response(message, history, image_input):
46
+ """
47
+ Main generation function called by Gradio.
48
+ """
49
+ if image_input is None:
50
+ return "Please upload an image first to chat about it!"
51
+
52
+ # 1. Prepare text prompt
53
+ prompt_text = format_prompt(image_input, history, message)
54
+
55
+ # 2. Process inputs (Image + Text)
56
+ # Converting image to RGB is important as some PNGs have alpha channels
57
+ image = image_input.convert("RGB")
58
+
59
+ inputs = processor(text=prompt_text, images=image, return_tensors="pt").to(model.device)
60
+
61
+ # 3. Generate Response
62
+ # max_new_tokens determines how long the answer can be
63
+ output = model.generate(
64
+ **inputs,
65
+ max_new_tokens=200,
66
+ do_sample=True,
67
+ temperature=0.7,
68
+ top_p=0.9
69
+ )
70
+
71
+ # 4. Decode output
72
+ decoded_output = processor.batch_decode(output, skip_special_tokens=True)[0]
73
+
74
+ # The raw output contains the prompt, so we strip it out to get just the assistant's reply
75
+ # The prompt format is "USER: ... ASSISTANT:", so we split by ASSISTANT:
76
+ response = decoded_output.split("ASSISTANT:")[-1].strip()
77
+
78
+ return response
79
+
80
+ # --- Gradio UI Setup ---
81
+
82
+ with gr.Blocks(title="LLaVA Image Chat", theme=gr.themes.Soft()) as demo:
83
+ gr.Markdown("# 🌋 LLaVA: Chat with Images")
84
+ gr.Markdown("Upload an image and ask questions about it using the LLaVA 1.5 Model.")
85
+
86
+ with gr.Row():
87
+ with gr.Column(scale=1):
88
+ image_box = gr.Image(type="pil", label="Upload Image")
89
+
90
+ with gr.Column(scale=2):
91
+ chatbot = gr.ChatInterface(
92
+ fn=chat_response,
93
+ additional_inputs=[image_box],
94
+ title="Chat",
95
+ description="Ask about the uploaded image.",
96
+ examples=["What is in this image?", "Describe the colors.", "Can you read the text in the image?"],
97
+ )
98
+
99
+ if __name__ == "__main__":
100
+ # queue() is required for generator/streaming interactions in some environments
101
+ demo.queue().launch()