anaspro committed on
Commit 2447f46 · 1 Parent(s): ab5a9ba
Files changed (1)
  1. app.py +200 -132
app.py CHANGED
@@ -1,166 +1,234 @@
- import gradio as gr
- import cv2
- import torch
- from PIL import Image
- from pathlib import Path
- from threading import Thread
- from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
- import spaces
- import time
  import os

- # model config - Single model: Shako v4
- model_name = "anaspro/Shako-4B-it"
-
- model_id = "anaspro/Shako-4B-it"
  processor = AutoProcessor.from_pretrained(model_id)
  model = AutoModelForImageTextToText.from_pretrained(
      model_id,
      device_map="auto",
      torch_dtype=torch.bfloat16
- ).eval()
- processor = AutoProcessor.from_pretrained(model_name, token=hf_token)
- # I will add timestamp later
- def extract_video_frames(video_path, num_frames=8):
-     cap = cv2.VideoCapture(video_path)
-     frames = []
-     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-     step = max(total_frames // num_frames, 1)
-
-     for i in range(num_frames):
-         cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
-         ret, frame = cap.read()
-         if ret:
-             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-             frames.append(Image.fromarray(frame))
-     cap.release()
-     return frames
-
- def format_message(content, files):
-
-     message_content = []
-
-     if content:
-         parts = content.split('<image>')
-         for i, part in enumerate(parts):
-             if part.strip():
-                 message_content.append({"type": "text", "text": part.strip()})
-             if i < len(parts) - 1 and files:
-                 img = Image.open(files.pop(0))
-                 message_content.append({"type": "image", "image": img})
-     for file in files:
-         file_path = file if isinstance(file, str) else file.name
-         if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
-             img = Image.open(file_path)
-             message_content.append({"type": "image", "image": img})
-         elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
-             frames = extract_video_frames(file_path)
-             for frame in frames:
-                 message_content.append({"type": "image", "image": frame})
-     return message_content
-
- def format_conversation_history(chat_history):
      messages = []
-     current_user_content = []
-     for item in chat_history:
-         role = item["role"]
-         content = item["content"]
-         if role == "user":
-             if isinstance(content, str):
-                 current_user_content.append({"type": "text", "text": content})
-             elif isinstance(content, list):
-                 current_user_content.extend(content)
-             else:
-                 current_user_content.append({"type": "text", "text": str(content)})
-         elif role == "assistant":
              if current_user_content:
                  messages.append({"role": "user", "content": current_user_content})
                  current_user_content = []
-             messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
-     if current_user_content:
-         messages.append({"role": "user", "content": current_user_content})
      return messages

- @spaces.GPU()  # duration=120
- def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
-     if isinstance(input_data, dict) and "text" in input_data:
-         text = input_data["text"]
-         files = input_data.get("files", [])
-     else:
-         text = str(input_data)
-         files = []
-
-     new_message_content = format_message(text, files)
-     new_message = {"role": "user", "content": new_message_content}
-     system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
-     processed_history = format_conversation_history(chat_history)
-     messages = system_message + processed_history
-     if messages and messages[-1]["role"] == "user":
-         messages[-1]["content"].extend(new_message["content"])
-     else:
-         messages.append(new_message)
-     # Use the single Shako v4 model
      inputs = processor.apply_chat_template(
          messages,
          add_generation_prompt=True,
          tokenize=True,
          return_tensors="pt",
-         return_dict=True
-     ).to(model.device)
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(
          inputs,
          streamer=streamer,
          max_new_tokens=max_new_tokens,
          do_sample=True,
-         temperature=temperature,
-         top_p=top_p,
-         top_k=top_k,
-         repetition_penalty=repetition_penalty
      )
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-     yield "".join(outputs)

  demo = gr.ChatInterface(
-     fn=generate_response,
-     additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
-         gr.Textbox(
-             label="System Prompt",
-             value="You are a friendly chatbot. ",
-             lines=4,
-             placeholder="Change system prompt"
-         ),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
-         gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
-         gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0),
-     ],
-     examples=[
-         [{"text": "Explain this image", "files": ["examples/image1.jpg"]}],
-     ],
-     cache_examples=False,
      type="messages",
-     description="""
-     # شكو - Shako Iraqi AI
-     نموذج ذكاء عراقي متقدم يتحدث بالعراقي، يدعم الصور والفيديوهات والمحادثات الصوتية.
-     """,  # Arabic: "Shako Iraqi AI -- an advanced Iraqi AI model that speaks Iraqi Arabic and supports images, videos, and voice conversations."
-     fill_height=True,
      textbox=gr.MultimodalTextbox(
-         label="Query Input",
-         file_types=["image", "video"],
          file_count="multiple",
-         placeholder="Type your message or upload media"
      ),
-     stop_btn="Stop Generation",
      multimodal=True,
-     theme=gr.themes.Soft(),
  )

  if __name__ == "__main__":
  import os
+ import pathlib
+ import tempfile
+ from collections.abc import Iterator
+ from threading import Thread

+ import av
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForImageTextToText, AutoProcessor
+ from transformers.generation.streamers import TextIteratorStreamer

+ # Model configuration
+ model_id = "anaspro/Shako-4B-it-v3"
  processor = AutoProcessor.from_pretrained(model_id)
  model = AutoModelForImageTextToText.from_pretrained(
      model_id,
      device_map="auto",
      torch_dtype=torch.bfloat16
+ )
+
+ # Supported file types
+ IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
+ VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
+ AUDIO_FILE_TYPES = (".mp3", ".wav")
+
+ # Video processing settings
+ TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
+ MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
+ MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
+
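A side note on the defaults above: Python's `int()` accepts underscores as digit separators (PEP 515), so the `"10_000"` fallback parses cleanly, and each constant can be overridden through the environment. A quick check of the same pattern:

```python
import os

# int() has accepted underscore digit separators since Python 3.6 (PEP 515),
# so the "10_000" default above parses to 10000.
assert int("10_000") == 10_000

# Same env-var-with-default pattern as the Space's constants.
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
print(MAX_INPUT_TOKENS)  # 10000 unless MAX_INPUT_TOKENS is set in the environment
```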
+ def get_file_type(path: str) -> str:
+     if path.endswith(IMAGE_FILE_TYPES):
+         return "image"
+     if path.endswith(VIDEO_FILE_TYPES):
+         return "video"
+     if path.endswith(AUDIO_FILE_TYPES):
+         return "audio"
+     error_message = f"Unsupported file type: {path}"
+     raise ValueError(error_message)
+
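`get_file_type` dispatches on `str.endswith`, which accepts a tuple of suffixes but matches case-sensitively, so a file named `photo.PNG` would fall through to the `ValueError`. A minimal sketch of the behavior, with a hypothetical case-insensitive variant that is not part of this commit:

```python
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")

print("photo.png".endswith(IMAGE_FILE_TYPES))  # True
print("photo.PNG".endswith(IMAGE_FILE_TYPES))  # False -- endswith is case-sensitive

# Hypothetical tolerant variant: normalize before matching.
def get_file_type_ci(path: str) -> str:
    if path.lower().endswith(IMAGE_FILE_TYPES):
        return "image"
    raise ValueError(f"Unsupported file type: {path}")
```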
+ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
+     video_count = 0
+     non_video_count = 0
+     for path in paths:
+         if path.endswith(VIDEO_FILE_TYPES):
+             video_count += 1
+         else:
+             non_video_count += 1
+     return video_count, non_video_count
+
+
+ def validate_media_constraints(message: dict) -> bool:
+     video_count, non_video_count = count_files_in_new_message(message["files"])
+     if video_count > 1:
+         gr.Warning("Only one video is supported.")
+         return False
+     if video_count == 1 and non_video_count > 0:
+         gr.Warning("Mixing images and videos is not allowed.")
+         return False
+     return True
+
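The validator enforces two per-message constraints: at most one video, and no mixing of a video with other files. Assuming Gradio's multimodal message shape (`{"text": ..., "files": [...]}`), the three paths look roughly like this:

```python
# Hypothetical messages illustrating each branch of validate_media_constraints.
ok       = {"text": "describe this", "files": ["clip.mp4"]}            # -> True
too_many = {"text": "compare these", "files": ["a.mp4", "b.mp4"]}      # -> False, warns "Only one video is supported."
mixed    = {"text": "caption these", "files": ["a.mp4", "photo.png"]}  # -> False, warns "Mixing images and videos is not allowed."
```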
+ def extract_frames_to_tempdir(
+     video_path: str,
+     target_fps: float,
+     max_frames: int | None = None,
+     parent_dir: str | None = None,
+     prefix: str = "frames_",
+ ) -> str:
+     temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)
+
+     container = av.open(video_path)
+     video_stream = container.streams.video[0]
+
+     if video_stream.duration is None or video_stream.time_base is None:
+         raise ValueError("video_stream is missing duration or time_base")
+
+     time_base = video_stream.time_base
+     duration = float(video_stream.duration * time_base)
+     interval = 1.0 / target_fps
+
+     total_frames = int(duration * target_fps)
+     if max_frames is not None:
+         total_frames = min(total_frames, max_frames)
+
+     target_times = [i * interval for i in range(total_frames)]
+     target_index = 0
+
+     for frame in container.decode(video=0):
+         if frame.pts is None:
+             continue
+
+         timestamp = float(frame.pts * time_base)
+
+         if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
+             frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
+             frame.to_image().save(frame_path)
+             target_index += 1
+
+         if max_frames is not None and target_index >= max_frames:
+             break
+
+     container.close()
+     return temp_dir
+
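`extract_frames_to_tempdir` decodes the stream sequentially with PyAV and saves a frame whenever its presentation timestamp lands within half a sampling interval of the next target time, so it never seeks. The caller owns the returned directory and the Space itself never deletes it; a standalone caller would want to clean up, as in this sketch (the clip path is hypothetical):

```python
import pathlib
import shutil

# Sample ~3 frames per second, capped at 30, from a hypothetical local clip.
temp_dir = extract_frames_to_tempdir("clips/demo.mp4", target_fps=3.0, max_frames=30)
try:
    frames = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
    print(f"extracted {len(frames)} frames")   # frame_0000.jpg, frame_0001.jpg, ...
finally:
    shutil.rmtree(temp_dir)  # the app leaves its temp dirs in place; clean up here
```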
+ def process_new_user_message(message: dict) -> list[dict]:
+     if not message["files"]:
+         return [{"type": "text", "text": message["text"]}]
+
+     file_types = [get_file_type(path) for path in message["files"]]
+
+     if len(file_types) == 1 and file_types[0] == "video":
+         gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
+
+         temp_dir = extract_frames_to_tempdir(
+             message["files"][0],
+             target_fps=TARGET_FPS,
+             max_frames=MAX_FRAMES,
+         )
+         paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
+         return [
+             {"type": "text", "text": message["text"]},
+             *[{"type": "image", "image": path.as_posix()} for path in paths],
+         ]
+
+     return [
+         {"type": "text", "text": message["text"]},
+         *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
+     ]
+
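So a video message becomes one text part followed by up to MAX_FRAMES image parts, while any other mix keeps one typed part per file. For a plain text-plus-image message, the returned content list would look roughly like this (paths hypothetical):

```python
message = {"text": "What is in this photo?", "files": ["/tmp/gradio/cat.png"]}

# process_new_user_message(message) would return:
# [
#     {"type": "text", "text": "What is in this photo?"},
#     {"type": "image", "image": "/tmp/gradio/cat.png"},
# ]
```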
+ def process_history(history: list[dict]) -> list[dict]:
      messages = []
+     current_user_content: list[dict] = []
+     for item in history:
+         if item["role"] == "assistant":
              if current_user_content:
                  messages.append({"role": "user", "content": current_user_content})
                  current_user_content = []
+             messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
+         else:
+             content = item["content"]
+             if isinstance(content, str):
+                 current_user_content.append({"type": "text", "text": content})
+             else:
+                 filepath = content[0]
+                 file_type = get_file_type(filepath)
+                 current_user_content.append({"type": file_type, file_type: filepath})
      return messages

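`process_history` folds Gradio's flat messages-format history back into the nested chat-template shape: user text strings and single-file tuples accumulate into one user turn, flushed when an assistant turn appears. Note that trailing user content with no assistant reply after it is not flushed here; the current message is appended separately in `generate`. A worked example (paths hypothetical):

```python
# In Gradio's "messages" history, file attachments arrive as separate user
# entries whose content is a (filepath,) tuple.
history = [
    {"role": "user", "content": ("/tmp/gradio/cat.png",)},
    {"role": "user", "content": "What is in this photo?"},
    {"role": "assistant", "content": "A cat on a sofa."},
]

# process_history(history) would return:
# [
#     {"role": "user", "content": [
#         {"type": "image", "image": "/tmp/gradio/cat.png"},
#         {"type": "text", "text": "What is in this photo?"},
#     ]},
#     {"role": "assistant", "content": [{"type": "text", "text": "A cat on a sofa."}]},
# ]
```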
+
+ @spaces.GPU()
+ @torch.inference_mode()
+ def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
+     if not validate_media_constraints(message):
+         yield ""
+         return
+
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+     messages.extend(process_history(history))
+     messages.append({"role": "user", "content": process_new_user_message(message)})
+
      inputs = processor.apply_chat_template(
          messages,
          add_generation_prompt=True,
          tokenize=True,
+         return_dict=True,
          return_tensors="pt",
+     )
+     n_tokens = inputs["input_ids"].shape[1]
+     if n_tokens > MAX_INPUT_TOKENS:
+         gr.Warning(
+             f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. This limit is set to avoid CUDA out-of-memory errors in this Space."
+         )
+         yield ""
+         return
+
+     inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
+
+     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
          inputs,
          streamer=streamer,
          max_new_tokens=max_new_tokens,
          do_sample=True,
+         temperature=1.0,
+         top_k=64,
+         top_p=0.95,
+         min_p=0.0,
+         disable_compile=True,
      )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     output = ""
+     for delta in streamer:
+         output += delta
+         yield output
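Generation runs on a background thread while `TextIteratorStreamer` feeds tokens back to the UI loop; yielding the accumulated `output` (rather than each delta) is what `gr.ChatInterface` expects, since it re-renders the full message on every yield. With `timeout=30.0`, the streamer's internal queue raises `queue.Empty` if no token arrives for 30 seconds, which unblocks a stalled consumer. The same pattern in isolation, assuming `model`, `processor`, and a prepared `inputs` dict as above:

```python
from threading import Thread
from transformers.generation.streamers import TextIteratorStreamer

streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
# model.generate runs in the background and pushes decoded text into the streamer.
Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=64)).start()

output = ""
for delta in streamer:   # blocks up to 30 s per chunk, then raises queue.Empty
    output += delta
    print(output)        # in the app, this is `yield output` back to Gradio
```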

+
+ # Examples for the chat interface (with additional inputs: system_prompt, max_new_tokens)
+ examples = [
+     ["What is the capital of France?", "You are a helpful assistant.", 700],
+     ["Explain quantum computing in simple terms", "You are a helpful assistant.", 512],
+     ["Write a short story about a robot learning to paint", "You are a helpful assistant.", 1000],
+ ]
+
+ # Create the chat interface
  demo = gr.ChatInterface(
+     fn=generate,
      type="messages",
      textbox=gr.MultimodalTextbox(
+         file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
          file_count="multiple",
+         autofocus=True,
      ),
      multimodal=True,
+     additional_inputs=[
+         gr.Textbox(label="System Prompt", value="انت موديل عراقي عادي من بغداد، ذكي ومرح. تتحدث بالعراقي فقط وتجاوب بتفصيل حسب السؤال. ما تستخدم فصحى ابدا."),  # Arabic: "You are a regular Iraqi model from Baghdad, smart and fun. You speak only Iraqi Arabic and answer in as much detail as the question calls for. You never use Modern Standard Arabic."
+         gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
+     ],
+     title="Shako IRAQI AI",
+     examples=examples,
+     stop_btn=False,
  )

  if __name__ == "__main__":