anaspro commited on
Commit
3f9e0fd
·
verified ·
1 Parent(s): 5ac765e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +252 -0
app.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+ import tempfile
4
+ from collections.abc import Iterator
5
+ from threading import Thread
6
+
7
+ import av
8
+ import gradio as gr
9
+ import spaces
10
+ import torch
11
+ from transformers import AutoModelForImageTextToText, AutoProcessor
12
+ from transformers.generation.streamers import TextIteratorStreamer
13
+
14
# Model configuration: load the multimodal chat model and its processor once at
# import time so every request reuses the same weights.
model_id = "anaspro/Shako-iraqi-4B-it"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",  # let transformers place weights on the available device(s)
    torch_dtype=torch.bfloat16  # bf16 halves memory vs fp32 while keeping fp32's exponent range
)
22
+
23
# Supported file types (compared against path suffixes below)
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
AUDIO_FILE_TYPES = (".mp3", ".wav")

# Video processing settings (overridable via environment variables)
TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))  # frames sampled per second of video
MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))  # hard cap on frames per video
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))  # prompt-size guard


def get_file_type(path: str) -> str:
    """Classify *path* as ``"image"``, ``"video"``, or ``"audio"`` by extension.

    The comparison is case-insensitive: the previous code matched the raw
    path, so uploads with uppercase extensions (e.g. ``PHOTO.JPG``) were
    wrongly rejected.

    Raises:
        ValueError: if the extension is not in any supported set.
    """
    lowered = path.lower()
    if lowered.endswith(IMAGE_FILE_TYPES):
        return "image"
    if lowered.endswith(VIDEO_FILE_TYPES):
        return "video"
    if lowered.endswith(AUDIO_FILE_TYPES):
        return "audio"
    error_message = f"Unsupported file type: {path}"
    raise ValueError(error_message)
43
+
44
+
45
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Count how many of *paths* are videos vs. anything else.

    The suffix check is case-insensitive (the original compared raw paths, so
    an ``.MP4`` upload was miscounted as non-video, which broke the
    video/image mixing validation downstream).

    Returns:
        A ``(video_count, non_video_count)`` pair.
    """
    video_count = sum(1 for path in paths if path.lower().endswith(VIDEO_FILE_TYPES))
    return video_count, len(paths) - video_count
54
+
55
+
56
def validate_media_constraints(message: dict) -> bool:
    """Check the upload rules for a new message.

    Shows a gradio warning and returns ``False`` when more than one video is
    attached, or when a video is mixed with other files; otherwise ``True``.
    """
    videos, others = count_files_in_new_message(message["files"])
    if videos > 1:
        gr.Warning("Only one video is supported.")
        return False
    if videos == 1 and others > 0:
        gr.Warning("Mixing images and videos is not allowed.")
        return False
    return True
65
+
66
+
67
def extract_frames_to_tempdir(
    video_path: str,
    target_fps: float,
    max_frames: int | None = None,
    parent_dir: str | None = None,
    prefix: str = "frames_",
) -> str:
    """Sample frames from *video_path* at roughly *target_fps* into a temp dir.

    Frames are written as ``frame_0000.jpg``, ``frame_0001.jpg``, ... inside a
    directory created with :func:`tempfile.mkdtemp`; the caller owns (and must
    eventually delete) that directory.

    Args:
        video_path: Path to a video file readable by PyAV.
        target_fps: Desired sampling rate in frames per second of video time.
        max_frames: Optional hard cap on the number of extracted frames.
        parent_dir: Optional parent directory for the temp dir.
        prefix: Temp-directory name prefix.

    Returns:
        The path of the temp directory containing the extracted JPEGs.

    Raises:
        ValueError: if the video stream lacks duration/time_base metadata.
    """
    temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)

    container = av.open(video_path)
    try:
        video_stream = container.streams.video[0]

        if video_stream.duration is None or video_stream.time_base is None:
            raise ValueError("video_stream is missing duration or time_base")

        time_base = video_stream.time_base
        duration = float(video_stream.duration * time_base)
        interval = 1.0 / target_fps

        total_frames = int(duration * target_fps)
        if max_frames is not None:
            total_frames = min(total_frames, max_frames)

        # Ideal timestamps (in seconds) at which we want to keep a frame.
        target_times = [i * interval for i in range(total_frames)]
        target_index = 0

        for frame in container.decode(video=0):
            if frame.pts is None:
                continue  # frames without a timestamp cannot be matched

            timestamp = float(frame.pts * time_base)

            # Keep the first decoded frame within half an interval of the
            # current target time, then advance to the next target.
            if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
                frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
                frame.to_image().save(frame_path)
                target_index += 1

            if max_frames is not None and target_index >= max_frames:
                break
    finally:
        # BUG FIX: previously the container was left open when decoding (or the
        # metadata check) raised, leaking the underlying file handle. Always close.
        container.close()

    return temp_dir
109
+
110
+
111
def process_new_user_message(message: dict) -> list[dict]:
    """Convert a gradio multimodal message into chat-template content parts.

    A lone video attachment is expanded into sampled frames (as image parts);
    any other attachments are passed through with their detected type.
    """
    files = message["files"]
    text_part = {"type": "text", "text": message["text"]}
    if not files:
        return [text_part]

    kinds = [get_file_type(p) for p in files]

    if kinds == ["video"]:
        gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
        frame_dir = extract_frames_to_tempdir(
            files[0],
            target_fps=TARGET_FPS,
            max_frames=MAX_FRAMES,
        )
        parts = [text_part]
        for frame in sorted(pathlib.Path(frame_dir).glob("*.jpg")):
            parts.append({"type": "image", "image": frame.as_posix()})
        return parts

    parts = [text_part]
    for path, kind in zip(files, kinds, strict=True):
        parts.append({"type": kind, kind: path})
    return parts
135
+
136
+
137
def process_history(history: list[dict]) -> list[dict]:
    """Convert gradio chat history into the chat-template message format.

    Consecutive user items (text and file attachments arrive as separate
    history entries) are merged into one user message; assistant items pass
    through as plain text.  A non-string user ``content`` is a gradio file
    entry whose first element is the file path.
    """
    messages = []
    current_user_content: list[dict] = []

    def flush() -> None:
        # Emit any accumulated user parts as a single user message.
        nonlocal current_user_content
        if current_user_content:
            messages.append({"role": "user", "content": current_user_content})
            current_user_content = []

    for item in history:
        if item["role"] == "assistant":
            flush()
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                filepath = content[0]
                file_type = get_file_type(filepath)
                current_user_content.append({"type": file_type, file_type: filepath})

    # BUG FIX: previously a trailing user turn (no assistant reply yet) was
    # silently dropped after the loop; flush it so no context is lost.
    flush()
    return messages
155
+
156
+
157
@spaces.GPU()
@torch.inference_mode()
def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    """Stream a chat completion for *message* given *history*.

    Yields the accumulated response text after each new chunk so gradio can
    render a live stream.  Yields a single empty string and returns early when
    the media constraints or the input-token budget are violated.

    Args:
        message: Gradio multimodal message (``{"text": ..., "files": [...]}``).
        history: Prior chat turns in gradio's "messages" format.
        system_prompt: Optional system message prepended to the conversation.
        max_new_tokens: Generation length cap.
    """
    if not validate_media_constraints(message):
        yield ""
        return

    # Build the full conversation: system prompt, prior turns, new user turn.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Reject oversized prompts before moving anything to the GPU.
    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        gr.Warning(
            f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. This limit is set to avoid CUDA out-of-memory errors in this Space."
        )
        yield ""
        return

    inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

    # TextIteratorStreamer lets us iterate generated text while model.generate
    # runs in a background thread; timeout guards against a stalled generation.
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    # dict(inputs, ...) merges the tokenized inputs with the sampling settings
    # into the kwargs for model.generate.
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=1.0,
        top_k=64,
        top_p=0.95,
        min_p=0.0,
        repetition_penalty=1.0,
        disable_compile=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Re-yield the growing string (not deltas): gradio replaces the message
    # body with each yielded value.
    output = ""
    for delta in streamer:
        output += delta
        yield output
208
+
209
+
210
# Examples for the chat interface (with additional inputs: system_prompt, max_new_tokens)
# NOTE(review): each example is [message, system_prompt, max_new_tokens]; with
# multimodal=True gradio also accepts {"text": ..., "files": []} — confirm
# plain strings render correctly in this gradio version.
examples = [
    ["What is the capital of France?", "You are a helpful assistant.", 700],
    ["Explain quantum computing in simple terms", "You are a helpful assistant.", 512],
    ["Write a short story about a robot learning to paint", "You are a helpful assistant.", 1000]
]

# Default system prompt: an Iraqi-Arabic persona (runtime string — shown in
# the UI textbox below and editable by the user).
system_prompt = (
    "انت موديل عراقي ذكي من بغداد. تتحدث باللهجة العراقية فقط. "
    "جاوب على كل سؤال بشرح كامل وموسع، ووضح الأسباب والخلفية والمعلومات المهمة. "
    "استخدم أمثلة عراقية واقعية أو حياتية كلما أمكن. "
    "تجنب الفصحى نهائيًا، وخلي الرد مطول وممتع."
)
# Create the chat interface
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    # Multimodal textbox accepting every supported image/video/audio extension.
    textbox=gr.MultimodalTextbox(
        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    # Extra controls passed to generate() after (message, history).
    additional_inputs=[
        gr.Textbox(label="System Prompt", value=system_prompt),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2048, step=10, value=2048),
    ],
    title="Shako IRAQI AI",
    examples=examples,
    stop_btn=False,
    # Right-to-left layout and Arabic-friendly fonts for the whole app.
    css="""
    .gradio-container, .chatbot, .chatbot * {
        direction: rtl !important;
        text-align: right !important;
        unicode-bidi: plaintext !important;
        font-family: 'Tajawal', 'Cairo', sans-serif;
    }
    """
)
249
+
250
+
251
# Launch the gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()