anaspro committed on
Commit
f202e6b
·
1 Parent(s): 6b28f41
Files changed (4) hide show
  1. app.py +147 -203
  2. app2.py +10 -4
  3. examples/image1.jpg +0 -0
  4. requirements.txt +7 -8
app.py CHANGED
@@ -1,234 +1,178 @@
1
- import os
2
- import pathlib
3
- import tempfile
4
- from collections.abc import Iterator
5
- from threading import Thread
6
-
7
- import av
8
  import gradio as gr
9
- import spaces
10
  import torch
11
- from transformers import AutoModelForImageTextToText, AutoProcessor
12
- from transformers.generation.streamers import TextIteratorStreamer
 
 
 
 
13
 
14
- # Model configuration
15
- model_id = "anaspro/Shako-4B-it-v3"
16
- processor = AutoProcessor.from_pretrained(model_id)
17
- model = AutoModelForImageTextToText.from_pretrained(
18
- model_id,
19
  device_map="auto",
20
  torch_dtype=torch.bfloat16
21
- )
22
-
23
- # Supported file types
24
- IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
25
- VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
26
- AUDIO_FILE_TYPES = (".mp3", ".wav")
27
-
28
- # Video processing settings
29
- TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
30
- MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
31
- MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
32
-
33
-
34
- def get_file_type(path: str) -> str:
35
- if path.endswith(IMAGE_FILE_TYPES):
36
- return "image"
37
- if path.endswith(VIDEO_FILE_TYPES):
38
- return "video"
39
- if path.endswith(AUDIO_FILE_TYPES):
40
- return "audio"
41
- error_message = f"Unsupported file type: {path}"
42
- raise ValueError(error_message)
43
-
44
-
45
- def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
46
- video_count = 0
47
- non_video_count = 0
48
- for path in paths:
49
- if path.endswith(VIDEO_FILE_TYPES):
50
- video_count += 1
51
- else:
52
- non_video_count += 1
53
- return video_count, non_video_count
54
-
55
-
56
- def validate_media_constraints(message: dict) -> bool:
57
- video_count, non_video_count = count_files_in_new_message(message["files"])
58
- if video_count > 1:
59
- gr.Warning("Only one video is supported.")
60
- return False
61
- if video_count == 1 and non_video_count > 0:
62
- gr.Warning("Mixing images and videos is not allowed.")
63
- return False
64
- return True
65
-
66
-
67
- def extract_frames_to_tempdir(
68
- video_path: str,
69
- target_fps: float,
70
- max_frames: int | None = None,
71
- parent_dir: str | None = None,
72
- prefix: str = "frames_",
73
- ) -> str:
74
- temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)
75
-
76
- container = av.open(video_path)
77
- video_stream = container.streams.video[0]
78
-
79
- if video_stream.duration is None or video_stream.time_base is None:
80
- raise ValueError("video_stream is missing duration or time_base")
81
-
82
- time_base = video_stream.time_base
83
- duration = float(video_stream.duration * time_base)
84
- interval = 1.0 / target_fps
85
-
86
- total_frames = int(duration * target_fps)
87
- if max_frames is not None:
88
- total_frames = min(total_frames, max_frames)
89
-
90
- target_times = [i * interval for i in range(total_frames)]
91
- target_index = 0
92
-
93
- for frame in container.decode(video=0):
94
- if frame.pts is None:
95
- continue
96
-
97
- timestamp = float(frame.pts * time_base)
98
-
99
- if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
100
- frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
101
- frame.to_image().save(frame_path)
102
- target_index += 1
103
-
104
- if max_frames is not None and target_index >= max_frames:
105
- break
106
-
107
- container.close()
108
- return temp_dir
109
-
110
-
111
- def process_new_user_message(message: dict) -> list[dict]:
112
- if not message["files"]:
113
- return [{"type": "text", "text": message["text"]}]
114
-
115
- file_types = [get_file_type(path) for path in message["files"]]
116
-
117
- if len(file_types) == 1 and file_types[0] == "video":
118
- gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
119
-
120
- temp_dir = extract_frames_to_tempdir(
121
- message["files"][0],
122
- target_fps=TARGET_FPS,
123
- max_frames=MAX_FRAMES,
124
- )
125
- paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
126
- return [
127
- {"type": "text", "text": message["text"]},
128
- *[{"type": "image", "image": path.as_posix()} for path in paths],
129
- ]
130
-
131
- return [
132
- {"type": "text", "text": message["text"]},
133
- *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
134
- ]
135
-
136
-
137
- def process_history(history: list[dict]) -> list[dict]:
138
  messages = []
139
- current_user_content: list[dict] = []
140
- for item in history:
141
- if item["role"] == "assistant":
142
- if current_user_content:
143
- messages.append({"role": "user", "content": current_user_content})
144
- current_user_content = []
145
- messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
146
- else:
147
- content = item["content"]
148
  if isinstance(content, str):
149
  current_user_content.append({"type": "text", "text": content})
 
 
150
  else:
151
- filepath = content[0]
152
- file_type = get_file_type(filepath)
153
- current_user_content.append({"type": file_type, file_type: filepath})
 
 
 
 
 
154
  return messages
155
 
156
-
157
- @spaces.GPU()
158
- @torch.inference_mode()
159
- def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
160
- if not validate_media_constraints(message):
161
- yield ""
162
- return
163
-
164
- messages = []
165
- if system_prompt:
166
- messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
167
- messages.extend(process_history(history))
168
- messages.append({"role": "user", "content": process_new_user_message(message)})
169
-
 
 
 
 
 
 
 
 
 
 
170
  inputs = processor.apply_chat_template(
171
  messages,
172
  add_generation_prompt=True,
173
  tokenize=True,
174
- return_dict=True,
175
  return_tensors="pt",
176
- )
177
- n_tokens = inputs["input_ids"].shape[1]
178
- if n_tokens > MAX_INPUT_TOKENS:
179
- gr.Warning(
180
- f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. This limit is set to avoid CUDA out-of-memory errors in this Space."
181
- )
182
- yield ""
183
- return
184
-
185
- inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
186
-
187
- streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
188
- generate_kwargs = dict(
189
  inputs,
190
  streamer=streamer,
191
  max_new_tokens=max_new_tokens,
192
  do_sample=True,
193
- temperature=1.0,
194
- top_k=64,
195
- top_p=0.95,
196
- min_p=0.0,
197
- disable_compile=True,
198
  )
199
- t = Thread(target=model.generate, kwargs=generate_kwargs)
200
- t.start()
201
-
202
- output = ""
203
- for delta in streamer:
204
- output += delta
205
- yield output
206
 
207
-
208
- # Examples for the chat interface (with additional inputs: system_prompt, max_new_tokens)
209
- examples = [
210
- ["What is the capital of France?", "You are a helpful assistant.", 700],
211
- ["Explain quantum computing in simple terms", "You are a helpful assistant.", 512],
212
- ["Write a short story about a robot learning to paint", "You are a helpful assistant.", 1000]
213
- ]
214
-
215
- # Create the chat interface
216
  demo = gr.ChatInterface(
217
- fn=generate,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  type="messages",
 
 
 
 
 
219
  textbox=gr.MultimodalTextbox(
220
- file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
 
221
  file_count="multiple",
222
- autofocus=True,
223
  ),
 
224
  multimodal=True,
225
- additional_inputs=[
226
- gr.Textbox(label="System Prompt", value="انت موديل عراقي عادي من بغداد، ذكي ومرح. تتحدث بالعراقي فقط وتجاوب بتفصيل حسب السؤال. ما تستخدم فصحى ابدا."),
227
- gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
228
- ],
229
- title="Shako IRAQI AI",
230
- examples=examples,
231
- stop_btn=False,
232
  )
233
 
234
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import cv2
3
  import torch
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ from threading import Thread
7
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
8
+ import spaces
9
+ import time
10
 
11
+ # model config
12
+ model_12b_name = "google/gemma-3-12b-it"
13
+ model_4b_name = "google/gemma-3-4b-it"
14
+ model_12b = Gemma3ForConditionalGeneration.from_pretrained(
15
+ model_12b_name,
16
  device_map="auto",
17
  torch_dtype=torch.bfloat16
18
+ ).eval()
19
+ processor_12b = AutoProcessor.from_pretrained(model_12b_name)
20
+ model_4b = Gemma3ForConditionalGeneration.from_pretrained(
21
+ model_4b_name,
22
+ device_map="auto",
23
+ torch_dtype=torch.bfloat16
24
+ ).eval()
25
+ processor_4b = AutoProcessor.from_pretrained(model_4b_name)
26
+ # I will add timestamp later
27
+ def extract_video_frames(video_path, num_frames=8):
28
+ cap = cv2.VideoCapture(video_path)
29
+ frames = []
30
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
31
+ step = max(total_frames // num_frames, 1)
32
+
33
+ for i in range(num_frames):
34
+ cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
35
+ ret, frame = cap.read()
36
+ if ret:
37
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
38
+ frames.append(Image.fromarray(frame))
39
+ cap.release()
40
+ return frames
41
+
42
+ def format_message(content, files):
43
+
44
+ message_content = []
45
+
46
+ if content:
47
+ parts = content.split('<image>')
48
+ for i, part in enumerate(parts):
49
+ if part.strip():
50
+ message_content.append({"type": "text", "text": part.strip()})
51
+ if i < len(parts) - 1 and files:
52
+ img = Image.open(files.pop(0))
53
+ message_content.append({"type": "image", "image": img})
54
+ for file in files:
55
+ file_path = file if isinstance(file, str) else file.name
56
+ if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
57
+ img = Image.open(file_path)
58
+ message_content.append({"type": "image", "image": img})
59
+ elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
60
+ frames = extract_video_frames(file_path)
61
+ for frame in frames:
62
+ message_content.append({"type": "image", "image": frame})
63
+ return message_content
64
+
65
+ def format_conversation_history(chat_history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  messages = []
67
+ current_user_content = []
68
+ for item in chat_history:
69
+ role = item["role"]
70
+ content = item["content"]
71
+ if role == "user":
 
 
 
 
72
  if isinstance(content, str):
73
  current_user_content.append({"type": "text", "text": content})
74
+ elif isinstance(content, list):
75
+ current_user_content.extend(content)
76
  else:
77
+ current_user_content.append({"type": "text", "text": str(content)})
78
+ elif role == "assistant":
79
+ if current_user_content:
80
+ messages.append({"role": "user", "content": current_user_content})
81
+ current_user_content = []
82
+ messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
83
+ if current_user_content:
84
+ messages.append({"role": "user", "content": current_user_content})
85
  return messages
86
 
87
+ @spaces.GPU(duration=120)
88
+ def generate_response(input_data, chat_history, model_choice, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
89
+ if isinstance(input_data, dict) and "text" in input_data:
90
+ text = input_data["text"]
91
+ files = input_data.get("files", [])
92
+ else:
93
+ text = str(input_data)
94
+ files = []
95
+
96
+ new_message_content = format_message(text, files)
97
+ new_message = {"role": "user", "content": new_message_content}
98
+ system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
99
+ processed_history = format_conversation_history(chat_history)
100
+ messages = system_message + processed_history
101
+ if messages and messages[-1]["role"] == "user":
102
+ messages[-1]["content"].extend(new_message["content"])
103
+ else:
104
+ messages.append(new_message)
105
+ if model_choice == "Gemma 3 12B":
106
+ model = model_12b
107
+ processor = processor_12b
108
+ else:
109
+ model = model_4b
110
+ processor = processor_4b
111
  inputs = processor.apply_chat_template(
112
  messages,
113
  add_generation_prompt=True,
114
  tokenize=True,
 
115
  return_tensors="pt",
116
+ return_dict=True
117
+ ).to(model.device)
118
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
119
+ generation_kwargs = dict(
 
 
 
 
 
 
 
 
 
120
  inputs,
121
  streamer=streamer,
122
  max_new_tokens=max_new_tokens,
123
  do_sample=True,
124
+ temperature=temperature,
125
+ top_p=top_p,
126
+ top_k=top_k,
127
+ repetition_penalty=repetition_penalty
 
128
  )
129
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
130
+ thread.start()
131
+
132
+ outputs = []
133
+ for text in streamer:
134
+ outputs.append(text)
135
+ yield "".join(outputs)
136
 
 
 
 
 
 
 
 
 
 
137
  demo = gr.ChatInterface(
138
+ fn=generate_response,
139
+ additional_inputs=[
140
+ gr.Dropdown(
141
+ label="Model",
142
+ choices=["Gemma 3 12B", "Gemma 3 4B"],
143
+ value="Gemma 3 12B"
144
+ ),
145
+ gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
146
+ gr.Textbox(
147
+ label="System Prompt",
148
+ value="You are a friendly chatbot. ",
149
+ lines=4,
150
+ placeholder="Change system prompt"
151
+ ),
152
+ gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
153
+ gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
154
+ gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
155
+ gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
156
+ ],
157
+ examples=[
158
+ [{"text": "Explain this image", "files": ["examples/image1.jpg"]}],
159
+ ],
160
+ cache_examples=False,
161
  type="messages",
162
+ description="""
163
+ # Gemma 3
164
+ You can pick your model 12B or 4B, upload images or videos, and adjust settings below to customize your experience.
165
+ """,
166
+ fill_height=True,
167
  textbox=gr.MultimodalTextbox(
168
+ label="Query Input",
169
+ file_types=["image", "video"],
170
  file_count="multiple",
171
+ placeholder="Type your message or upload media"
172
  ),
173
+ stop_btn="Stop Generation",
174
  multimodal=True,
175
+ theme=gr.themes.Soft(),
 
 
 
 
 
 
176
  )
177
 
178
  if __name__ == "__main__":
app2.py CHANGED
@@ -12,7 +12,7 @@ from transformers import AutoModelForImageTextToText, AutoProcessor
12
  from transformers.generation.streamers import TextIteratorStreamer
13
 
14
  # Model configuration
15
- model_id = "anaspro/Shako-4B-it-v2"
16
  processor = AutoProcessor.from_pretrained(model_id)
17
  model = AutoModelForImageTextToText.from_pretrained(
18
  model_id,
@@ -189,7 +189,11 @@ def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
189
  inputs,
190
  streamer=streamer,
191
  max_new_tokens=max_new_tokens,
192
- do_sample=False,
 
 
 
 
193
  disable_compile=True,
194
  )
195
  t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -203,7 +207,9 @@ def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
203
 
204
  # Examples for the chat interface (with additional inputs: system_prompt, max_new_tokens)
205
  examples = [
206
- ["انت موديل عراقي تحكي هعراقي فقط وتكون ترفيهي", 700]
 
 
207
  ]
208
 
209
  # Create the chat interface
@@ -217,7 +223,7 @@ demo = gr.ChatInterface(
217
  ),
218
  multimodal=True,
219
  additional_inputs=[
220
- gr.Textbox(label="System Prompt", value="انت ذكاء صناعي يتحدث باللهجة العراقية بس ما تستخدم فصحى ابدا"),
221
  gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
222
  ],
223
  title="Shako IRAQI AI",
 
12
  from transformers.generation.streamers import TextIteratorStreamer
13
 
14
  # Model configuration
15
+ model_id = "anaspro/Shako-4B-it-v3"
16
  processor = AutoProcessor.from_pretrained(model_id)
17
  model = AutoModelForImageTextToText.from_pretrained(
18
  model_id,
 
189
  inputs,
190
  streamer=streamer,
191
  max_new_tokens=max_new_tokens,
192
+ do_sample=True,
193
+ temperature=1.0,
194
+ top_k=64,
195
+ top_p=0.95,
196
+ min_p=0.0,
197
  disable_compile=True,
198
  )
199
  t = Thread(target=model.generate, kwargs=generate_kwargs)
 
207
 
208
  # Examples for the chat interface (with additional inputs: system_prompt, max_new_tokens)
209
  examples = [
210
+ ["What is the capital of France?", "You are a helpful assistant.", 700],
211
+ ["Explain quantum computing in simple terms", "You are a helpful assistant.", 512],
212
+ ["Write a short story about a robot learning to paint", "You are a helpful assistant.", 1000]
213
  ]
214
 
215
  # Create the chat interface
 
223
  ),
224
  multimodal=True,
225
  additional_inputs=[
226
+ gr.Textbox(label="System Prompt", value="انت موديل عراقي عادي من بغداد، ذكي ومرح. تتحدث بالعراقي فقط وتجاوب بتفصيل حسب السؤال. ما تستخدم فصحى ابدا."),
227
  gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
228
  ],
229
  title="Shako IRAQI AI",
examples/image1.jpg ADDED
requirements.txt CHANGED
@@ -1,8 +1,7 @@
1
- gradio>=4.0.0
2
- spaces[huggingface]>=0.28.0
3
- transformers>=4.35.0
4
- torch>=2.1.0
5
- av
6
- accelerate>=0.25.0
7
- timm
8
- gTTS>=2.5.0
 
1
+ transformers
2
+ spaces
3
+ torch
4
+ transformers @ git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
5
+ pillow
6
+ opencv-python
7
+ accelerate