achokshi commited on
Commit
6adb6a2
·
1 Parent(s): 21e41fc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import textwrap
2
+ from io import BytesIO
3
+
4
+ import requests
5
+ import torch
6
+ from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
7
+ from llava.conversation import SeparatorStyle, conv_templates
8
+ from llava.mm_utils import (
9
+ KeywordsStoppingCriteria,
10
+ get_model_name_from_path,
11
+ process_images,
12
+ tokenizer_image_token,
13
+ )
14
+ from llava.model.builder import load_pretrained_model
15
+ from llava.utils import disable_torch_init
16
+ from PIL import Image
17
+
18
+
19
# One-time model setup.
# Skip torch's default random weight initialization — the checkpoint supplies
# every weight, so initializing first would be wasted work.
disable_torch_init()

# 4-bit quantized LLaVA v1.5 13B checkpoint (fits in ~3 GB of VRAM).
MODEL = "4bit/llava-v1.5-13b-3GB"
model_name = get_model_name_from_path(MODEL)

# Load tokenizer, model, and the CLIP image processor in 4-bit mode.
# (The original had a bare `model_name` expression here — notebook display
# residue with no effect in a script — which has been removed.)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=MODEL, model_base=None, model_name=model_name, load_4bit=True
)
28
+
29
def process_image(image):
    """Preprocess a PIL image into a fp16 tensor on the model's device.

    Uses "pad" aspect-ratio handling so non-square images are padded rather
    than distorted before being fed to the vision tower.
    """
    preprocess_cfg = {"image_aspect_ratio": "pad"}
    tensor = process_images([image], image_processor, preprocess_cfg)
    return tensor.to(model.device, dtype=torch.float16)
33
+
34
# NOTE(review): the original notebook-style lines
#     processed_image = process_image(image)
#     type(processed_image), processed_image.shape
# referenced an undefined name `image` and would raise NameError the moment
# this script is imported/run; they were debugging residue and are removed.

# Conversation template key used to format prompts for this checkpoint.
CONV_MODE = "llava_v0"
38
+
39
def create_prompt(prompt: str):
    """Wrap *prompt* in a fresh LLaVA conversation, prefixed with the image token.

    Returns a ``(prompt_string, conversation)`` pair; the conversation object is
    needed later to derive the stop string for generation.
    """
    conversation = conv_templates[CONV_MODE].copy()
    user_role, assistant_role = conversation.roles[0], conversation.roles[1]
    # The image placeholder token must appear in the user turn so the model
    # knows where the image embedding is spliced into the sequence.
    conversation.append_message(user_role, DEFAULT_IMAGE_TOKEN + "\n" + prompt)
    conversation.append_message(assistant_role, None)
    return conversation.get_prompt(), conversation
46
+
47
+
48
+ prompt, _ = create_prompt("Describe the image")
49
+ print(prompt)
50
+
51
def ask_image(image: Image.Image, prompt: str) -> str:
    """Answer *prompt* about a PIL *image* with the loaded LLaVA model.

    The image is preprocessed to a fp16 tensor, the prompt is wrapped in the
    conversation template, and the model generates up to 512 new tokens.
    Returns the decoded answer with special tokens stripped.

    Note: the original annotated the parameter as ``Image`` (the PIL *module*);
    the class is ``Image.Image`` — fixed here.
    """
    image_tensor = process_image(image)
    prompt, conv = create_prompt(prompt)
    # Tokenize, replacing the image placeholder with IMAGE_TOKEN_INDEX, and add
    # a batch dimension for generate().
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    # Stop generating once the conversation separator appears; TWO-style
    # templates use their second separator after the assistant turn.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria(
        keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            # Sampling at temperature 0.01 is effectively greedy decoding.
            do_sample=True,
            temperature=0.01,
            max_new_tokens=512,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    # Drop the echoed prompt tokens; decode only the newly generated ones.
    return tokenizer.decode(
        output_ids[0, input_ids.shape[1] :], skip_special_tokens=True
    ).strip()
78
+
79
+
80
+ import gradio as gr
81
+ #import ask_image
82
+ import textwrap
83
+
84
def describe_image(image, text):
    """Gradio callback: answer *text* about *image*, wrapped to 110 columns.

    Delegates the vision-language work to ask_image() and only reflows the
    resulting paragraph so it displays nicely in the text output box.
    """
    answer = ask_image(image, text)
    return textwrap.fill(answer, width=110)
93
+
94
# Wire the describer into a minimal Gradio UI: an image upload plus a free-text
# question in, plain text out.
demo = gr.Interface(fn=describe_image, inputs=["image", "text"], outputs="text")

# Start the server; inline=False prevents Gradio from trying to embed the app
# in a notebook output cell.
demo.launch(inline=False)