Artem committed on
Commit
2323b4d
·
1 Parent(s): 03fd523

model switching

Browse files
eval.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import os
2
+ import time
3
+
4
+
future_work/adapters.py CHANGED
@@ -1,4 +1,3 @@
1
- from transformers import TextStreamer
2
  from unsloth import FastVisionModel
3
  from dotenv import load_dotenv
4
  import os
 
 
1
  from unsloth import FastVisionModel
2
  from dotenv import load_dotenv
3
  import os
future_work/dataset.py CHANGED
@@ -1,15 +1,14 @@
1
  from datasets import Dataset
2
- import torch
3
  from consts import REASONING_START, REASONING_END, SOLUTION_START, SOLUTION_END
4
 
5
 
6
 
7
  def is_numeric_answer(example):
8
- try:
9
- float(example["answer"])
10
- return True
11
- except:
12
- return False
13
 
14
  def resize_images(example):
15
  image = example["decoded_image"]
 
1
  from datasets import Dataset
 
2
  from consts import REASONING_START, REASONING_END, SOLUTION_START, SOLUTION_END
3
 
4
 
5
 
6
def is_numeric_answer(example):
    """Return True if the example's "answer" field parses as a float, else False.

    Used as a dataset filter predicate, so it must return a bool: the previous
    version returned the string ``f"error: {e}"`` on failure, which is truthy
    and would let every malformed answer through the filter.
    """
    try:
        float(example["answer"])
        return True
    except (ValueError, TypeError, KeyError):
        # ValueError: non-numeric text; TypeError: answer is None/non-string;
        # KeyError: example has no "answer" field at all.
        return False
12
 
13
  def resize_images(example):
14
  image = example["decoded_image"]
future_work/model.py CHANGED
@@ -1,5 +1,4 @@
1
  from unsloth import FastVisionModel
2
- import torch
3
  from consts import BASE_MODEL
4
 
5
 
 
1
  from unsloth import FastVisionModel
 
2
  from consts import BASE_MODEL
3
 
4
 
gradio_app.py CHANGED
@@ -1,74 +1,15 @@
1
- import torch
2
  import gradio as gr
3
- from PIL import Image
4
- from consts import BASE_MODEL
5
- from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
6
- from qwen_vl_utils import process_vision_info
7
 
 
 
 
 
 
 
8
 
9
- """
10
- Initalize Model
11
- """
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
-
14
- model = Qwen2VLForConditionalGeneration.from_pretrained(BASE_MODEL)
15
- processor = AutoProcessor.from_pretrained(BASE_MODEL)
16
-
17
-
18
- """
19
- Model Function
20
- """
21
- def query(image: Image.Image, question: str):
22
- if image is None:
23
- return "Upload an image bro."
24
-
25
- messages = [
26
- {
27
- "role": "user",
28
- "content": [
29
- {"type": "image", "image": image},
30
- {"type": "text", "text": question}
31
- ]
32
- }
33
- ]
34
-
35
- text = processor.apply_chat_template(
36
- messages,
37
- tokenize=False,
38
- add_generation_prompt=True
39
- )
40
-
41
- images, video_inputs = process_vision_info(messages)
42
-
43
- inputs = processor(
44
- text=text,
45
- images=images,
46
- videos=video_inputs,
47
- padding=True,
48
- return_tensors="pt")
49
-
50
- # Generate output
51
- generated_ids = model.generate(**inputs, max_new_tokens=256)
52
-
53
- # Trim the input tokens
54
- generated_ids_trimmed = [
55
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
56
- ]
57
-
58
- # Decode the output
59
- output_text = processor.batch_decode(
60
- generated_ids_trimmed,
61
- skip_special_tokens=True,
62
- clean_up_tokenization_spaces=False
63
- )
64
-
65
- return output_text[0]
66
-
67
-
68
-
69
- """
70
- Interface
71
- """
72
 
73
  custom_css = """
74
  .output-card {
@@ -81,55 +22,49 @@ custom_css = """
81
 
82
  with gr.Blocks(theme=gr.themes.Soft(), title="Qwen2-VL Analyst") as app:
83
 
84
- # Header
 
85
  gr.Markdown(
86
  r"""
87
- ¯\(ツ)/¯ Intelligence: Upload an image and ask a question
88
  """
89
  )
90
-
91
  with gr.Row():
92
- # Inputs
93
  with gr.Column(scale=1):
94
  img_input = gr.Image(type="pil", label="Upload Image", height=400)
95
- q_input = gr.Textbox(
96
- label="Question",
97
- lines=2
98
- )
99
-
100
  with gr.Row():
101
  clear_btn = gr.Button("Clear", variant="secondary")
102
  submit_btn = gr.Button("Analyze Image", variant="primary")
103
-
104
- # Output
105
  with gr.Column(scale=1):
 
 
 
 
106
  gr.Markdown("Model Analysis:")
107
-
108
  with gr.Group(elem_classes="output-card"):
109
- output_box = gr.Markdown(
110
- value="Results...",
111
- line_breaks=True
112
- )
113
 
114
- # Trigger on Button Click
115
  submit_btn.click(
116
- fn=query,
117
- inputs=[img_input, q_input],
118
- outputs=output_box
119
  )
120
-
121
- # Trigger on pressing Enter
 
122
  q_input.submit(
123
- fn=query,
124
- inputs=[img_input, q_input],
125
- outputs=output_box
126
  )
127
 
128
- # Clear button
 
129
  def clear_inputs():
130
  return None, "", ""
131
-
132
- clear_btn.click(fn=clear_inputs, inputs=[], outputs=[img_input, q_input, output_box])
 
 
133
 
134
 
135
  app.launch()
 
 
1
  import gradio as gr
2
+ from local_model import query_local
3
+ from remote_model import query_remote, pipe
4
+ import time
 
5
 
6
def query(image, question, model_name):
    """Dispatch the request to the backend selected in the UI dropdown.

    ``model_name`` is one of "Local" / "Remote"; any other value yields a
    plain status string instead of raising.
    """
    backends = {
        "Local": lambda: query_local(image, question),
        "Remote": lambda: query_remote(image, question, pipe),
    }
    handler = backends.get(model_name)
    if handler is None:
        return "No model selected"
    return handler()
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  custom_css = """
15
  .output-card {
 
22
 
23
  with gr.Blocks(theme=gr.themes.Soft(), title="Qwen2-VL Analyst") as app:
24
 
25
+ start_time = time.time()
26
+
27
  gr.Markdown(
28
  r"""
29
+ ¯\_(ツ)_/¯ Intelligence: Upload an image and ask a question
30
  """
31
  )
32
+
33
  with gr.Row():
 
34
  with gr.Column(scale=1):
35
  img_input = gr.Image(type="pil", label="Upload Image", height=400)
36
+ q_input = gr.Textbox(label="Question", lines=2)
 
 
 
 
37
  with gr.Row():
38
  clear_btn = gr.Button("Clear", variant="secondary")
39
  submit_btn = gr.Button("Analyze Image", variant="primary")
 
 
40
  with gr.Column(scale=1):
41
+ with gr.Row():
42
+ model_dropdown = gr.Dropdown(
43
+ label="Select Model", choices=["Local", "Remote"], value="Local"
44
+ )
45
  gr.Markdown("Model Analysis:")
46
+
47
  with gr.Group(elem_classes="output-card"):
48
+ output_box = gr.Markdown(value="Results...", line_breaks=True)
 
 
 
49
 
 
50
  submit_btn.click(
51
+ fn=query, inputs=[img_input, q_input, model_dropdown], outputs=output_box
 
 
52
  )
53
+
54
+
55
+
56
  q_input.submit(
57
+ fn=query, inputs=[img_input, q_input, model_dropdown], outputs=output_box
 
 
58
  )
59
 
60
+
61
+
62
def clear_inputs():
    """Reset the image, question box, and output panel to empty states."""
    cleared_state = (None, "", "")
    return cleared_state
64
+
65
+ clear_btn.click(
66
+ fn=clear_inputs, inputs=[], outputs=[img_input, q_input, output_box]
67
+ )
68
 
69
 
70
  app.launch()
local_model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import gradio as gr
from PIL import Image
from consts import BASE_MODEL
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import time

# Pick the best available device and keep the model AND its inputs on it.
# The original computed `device` but never used it, so the model always ran
# on CPU (and would have crashed on a device mismatch had it been moved).
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Qwen2VLForConditionalGeneration.from_pretrained(BASE_MODEL).to(device)
processor = AutoProcessor.from_pretrained(BASE_MODEL)


def query_local(image: Image.Image, question: str) -> str:
    """Answer *question* about *image* with the locally loaded Qwen2-VL model.

    Args:
        image: the uploaded PIL image.
        question: the user's free-text question about the image.

    Returns:
        The model's decoded answer text.

    Raises:
        ValueError: if no image is provided.
    """
    start_time = time.time()
    if image is None:
        raise ValueError("Missing image")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        }
    ]

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    images, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=text,
        images=images,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)  # keep input tensors on the same device as the model

    # Inference only: disable autograd bookkeeping during generation.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=256)

    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    print("local %s --- " % (time.time() - start_time))

    return output_text[0]
remote_model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from huggingface_hub import InferenceClient
import huggingface_hub
from consts import BASE_MODEL
from PIL import Image
from transformers import pipeline
import time


# Shared pipeline instance; passed into query_remote so callers can swap it.
pipe = pipeline("image-text-to-text", model=BASE_MODEL)


def query_remote(image: Image.Image, question: str, pipe):
    """Answer *question* about *image* via a transformers image-text-to-text pipeline.

    Args:
        image: the uploaded PIL image.
        question: the user's free-text question about the image.
        pipe: an "image-text-to-text" pipeline instance.

    Returns:
        The pipeline's generated answer text.

    Raises:
        ValueError: if no image is provided.
    """
    start_time = time.time()
    # Bug fix: the original tested `not Image` — the imported PIL *class*,
    # which is always truthy — so a missing image was never caught.
    # Validate the actual argument instead.
    if image is None:
        raise ValueError("Missing image")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        }
    ]

    outputs = pipe(text=messages, return_full_text=False)

    print("remote time %s --- " % (time.time() - start_time))

    return outputs[0]["generated_text"]