anaspro committed on
Commit eef5702 · verified · 1 Parent(s): 8070fa9

Update app2.py

Files changed (1)
app2.py +208 -159

app2.py CHANGED
@@ -1,179 +1,228 @@
- # -*- coding: utf-8 -*-
-
  import os
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import gradio as gr
  import spaces

- # Load system prompt from file
- def load_system_prompt():
-     try:
-         with open('system_prompt.txt', 'r', encoding='utf-8') as f:
-             return f.read().strip()
-     except FileNotFoundError:
-         return "أنت مساعد ذكي مفيد."
-
- DEFAULT_SYSTEM_PROMPT = load_system_prompt()
-
- model_path = "inceptionai/jais-adapted-7b-chat"
-
- # Jais chat prompts from the documentation
- prompt_eng = """### Instruction:Your name is 'Jais', and you are named after Jebel Jais, the highest mountain in UAE. You were made by 'Inception' in the UAE. You are a helpful, respectful, and honest assistant. Always answer as helpfully as possible, while being safe. Complete the conversation between [|Human|] and [|AI|]:
- ### Input: [|Human|] {Question}
- [|AI|]
- ### Response :"""
-
- prompt_ar = """### Instruction:اسمك "جيس" وسميت على اسم جبل جيس اعلى جبل في الامارات. تم بنائك بواسطة Inception في الإمارات. أنت مساعد مفيد ومحترم وصادق. أجب دائمًا بأكبر قدر ممكن من المساعدة، مع الحفاظ على البقاء أمناً. أكمل المحادثة بين [|Human|] و[|AI|] :
- ### Input:[|Human|] {Question}
- [|AI|]
- ### Response :"""
-
- # Use HF_TOKEN from the environment if it is set
- hf_token = os.getenv("HF_TOKEN")
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)
- model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", trust_remote_code=True, token=hf_token)
-
- if tokenizer.pad_token is None:
-     tokenizer.pad_token = tokenizer.eos_token
-
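- # Generation helper from the model documentation: tokenize the full prompt,
- # sample a completion, and return only the text after the last "### Response :".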
- def get_response(text, tokenizer=tokenizer, model=model):
-     """Same function as in the documentation, adapted for the chat model"""
-     tokenized = tokenizer(text, return_tensors="pt")
-     input_ids, attention_mask = tokenized['input_ids'].to(device), tokenized['attention_mask'].to(device)
-     input_len = input_ids.shape[-1]
-     generate_ids = model.generate(
-         input_ids,
-         attention_mask=attention_mask,
-         top_p=0.9,
-         temperature=0.3,
-         max_length=2048,
-         min_length=input_len + 4,
-         repetition_penalty=1.2,
-         do_sample=True,
-         pad_token_id=tokenizer.pad_token_id
-     )
-     response = tokenizer.batch_decode(
-         generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-     )[0]
-     response = response.split("### Response :")[-1].lstrip()
-     return response
-
- def format_conversation_history(chat_history):
      messages = []
-     for item in chat_history:
-         role = item["role"]
-         content = item["content"]
-         if isinstance(content, list):
-             content = content[0]["text"] if content and "text" in content[0] else str(content)
-         messages.append({"role": role, "content": content})
      return messages

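- # Heuristic: count characters in the Arabic Unicode block (U+0600 to U+06FF);
- # if more than 30% of the non-space characters are Arabic, reply in Arabic.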
- def detect_language(text):
-     """Simple language detection - Arabic vs English"""
-     arabic_chars = sum(1 for char in text if '\u0600' <= char <= '\u06FF')
-     total_chars = len(text.replace(' ', ''))
-
-     if total_chars == 0:
-         return 'ar'  # default to Arabic
-
-     arabic_ratio = arabic_chars / total_chars
-     return 'ar' if arabic_ratio > 0.3 else 'en'

  @spaces.GPU()
- def generate_response(input_data, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
-     # Detect the language of the current question
-     lang = detect_language(input_data)
-     prompt_template = prompt_ar if lang == 'ar' else prompt_eng
-
-     # Build the conversation in the Jais format
-     conversation_parts = []
-
-     # Add the system prompt as part of the instruction (keep it short for Jais)
-     system_instruction = "اسمك \"أليكس\" وأنت مساعد خدمة العملاء في شركة TechSolutions. مهمتك مساعدة العملاء في حل مشاكلهم مع المنتجات والإجابة عن أسئلتهم حول الخدمات. كن ودوداً وصبوراً ومحترماً. أجب بالعربية أو الإنجليزية حسب تفضيل العميل. ابدأ بالتحية وكن مباشراً في الحلول."
-
-     # Add the chat history
-     if chat_history:
-         for item in chat_history:
-             role = item["role"]
-             content = item["content"]
-             if isinstance(content, list):
-                 content = content[0]["text"] if content and "text" in content[0] else str(content)

-             if role == "user":
-                 conversation_parts.append(f"[|Human|] {content}")
-             elif role == "assistant":
-                 conversation_parts.append(f"[|AI|] {content}")
-
-     # Add the current user message
-     conversation_parts.append(f"[|Human|] {input_data}")
-     conversation_parts.append("[|AI|]")
-
-     # Join the conversation turns
-     conversation = "\n".join(conversation_parts)
-
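-     # Jais takes a single flat prompt: the instruction, the [|Human|]/[|AI|]
-     # turns as the input, and a trailing "### Response :" marker to complete.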
-     # Create the full prompt in the Jais format with our system prompt
-     full_prompt = f"### Instruction:{system_instruction}\n### Input:{conversation}\n### Response :"
-
-     try:
-         # Use the get_response helper from the documentation
-         response = get_response(full_prompt)
-
-         # Extract only the new reply (after "### Response :")
-         if "### Response :" in response:
-             response = response.split("### Response :")[-1].strip()

-         if not response:
-             response = "أهلاً! أنا أليكس مساعد خدمة العملاء. كيف أقدر أساعدك اليوم؟"

-         yield response

-     except Exception as e:
-         print(f"Error in generate_response: {e}")
-         import traceback
-         print(traceback.format_exc())
-         yield "أهلاً! أنا أليكس مساعد خدمة العملاء. كيف أقدر أساعدك اليوم؟"

  demo = gr.ChatInterface(
-     fn=generate_response,
-     additional_inputs=[
-         gr.Slider(label="الحد الأقصى للكلمات الجديدة", minimum=64, maximum=4096, step=1, value=2048),
-         gr.Slider(label="درجة الحرارة", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
-         gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
-         gr.Slider(label="عقوبة التكرار", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
-     ],
-     examples=[
-         [{"text": "النت عندي معطل من الصبح، تقدر تساعدني؟"}],
-         [{"text": "عندي مشكلة بالاتصال بالواي فاي"}],
-         [{"text": "شنو الباقات المتوفرة عندكم؟"}],
-         [{"text": "كيف أعيد ضبط الجهاز؟"}],
-         [{"text": "My device is not working properly"}],
-     ],
-     cache_examples=False,
      type="messages",
-     title="دعم عملاء TechSolutions - مساعد أليكس (العراقي)",
-     description="""🤖 مساعد خدمة عملاء ذكي لـ TechSolutions
-
- ✨ المميزات:
- - 🌐 دعم ثنائي اللغة (عربي وإنجليزي)
- - 💬 لهجة محادثة طبيعية
- - 🔧 دعم فني واستكشاف الأخطاء
- - 📋 معلومات الخدمات والإرشاد
- - 🎯 مدعوم بـ موديل Unsloth Meta-Llama-3.1-8B-Instruct (مع تحسينات الأداء)
-
- احجي مع أليكس لحل مشاكلك التقنية، استفسر عن الخدمات، أو احصل على معلومات المنتجات.""",
-     fill_height=True,
-     textbox=gr.Textbox(
-         label="اكتب رسالتك هنا",
-         placeholder="مثال: عندي مشكلة بالجهاز..."
      ),
-     stop_btn="إيقاف التوليد",
-     multimodal=False,
-     theme=gr.themes.Soft()
  )

  if __name__ == "__main__":
  import os
+ import pathlib
+ import tempfile
+ from collections.abc import Iterator
+ from threading import Thread
+
+ import av
  import gradio as gr
  import spaces
+ import torch
+ from transformers import AutoModelForImageTextToText, AutoProcessor
+ from transformers.generation.streamers import TextIteratorStreamer
+
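+ # The processor and model load once at startup; bfloat16 halves memory use
+ # compared to float32, and device_map="auto" handles device placement.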
+ # Model configuration
+ model_id = "anaspro/Shako-4B-it-v2"
+ processor = AutoProcessor.from_pretrained(model_id)
+ model = AutoModelForImageTextToText.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16
+ )

+ # Supported file types
+ IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
+ VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
+ AUDIO_FILE_TYPES = (".mp3", ".wav")
+
+ # Video processing settings
+ TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
+ MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
+ MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
+
+
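+ # Classify an upload by extension; str.endswith accepts a tuple, so each
+ # group of extensions is checked in a single call.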
+ def get_file_type(path: str) -> str:
+     if path.endswith(IMAGE_FILE_TYPES):
+         return "image"
+     if path.endswith(VIDEO_FILE_TYPES):
+         return "video"
+     if path.endswith(AUDIO_FILE_TYPES):
+         return "audio"
+     error_message = f"Unsupported file type: {path}"
+     raise ValueError(error_message)
+
+
+ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
+     video_count = 0
+     non_video_count = 0
+     for path in paths:
+         if path.endswith(VIDEO_FILE_TYPES):
+             video_count += 1
+         else:
+             non_video_count += 1
+     return video_count, non_video_count
+
+
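+ # This Space accepts at most one video per message, and a video cannot be
+ # combined with images or audio in the same message.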
+ def validate_media_constraints(message: dict) -> bool:
+     video_count, non_video_count = count_files_in_new_message(message["files"])
+     if video_count > 1:
+         gr.Warning("Only one video is supported.")
+         return False
+     if video_count == 1 and non_video_count > 0:
+         gr.Warning("Mixing images and videos is not allowed.")
+         return False
+     return True
+
+
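+ # Frame sampling: decode the video once, sequentially, keeping each frame whose
+ # timestamp falls within half a sampling interval of the next target time, so
+ # the output approximates target_fps without any seeking.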
+ def extract_frames_to_tempdir(
+     video_path: str,
+     target_fps: float,
+     max_frames: int | None = None,
+     parent_dir: str | None = None,
+     prefix: str = "frames_",
+ ) -> str:
+     temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)
+
+     container = av.open(video_path)
+     video_stream = container.streams.video[0]
+
+     if video_stream.duration is None or video_stream.time_base is None:
+         raise ValueError("video_stream is missing duration or time_base")
+
+     time_base = video_stream.time_base
+     duration = float(video_stream.duration * time_base)
+     interval = 1.0 / target_fps
+
+     total_frames = int(duration * target_fps)
+     if max_frames is not None:
+         total_frames = min(total_frames, max_frames)
+
+     target_times = [i * interval for i in range(total_frames)]
+     target_index = 0
+
+     for frame in container.decode(video=0):
+         if frame.pts is None:
+             continue
+
+         timestamp = float(frame.pts * time_base)
+
+         if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
+             frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
+             frame.to_image().save(frame_path)
+             target_index += 1
+
+         if max_frames is not None and target_index >= max_frames:
+             break
+
+     container.close()
+     return temp_dir
+
+
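+ # Convert the new user turn into chat-template content parts; a lone video is
+ # expanded into its sampled frames and passed to the model as images.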
+ def process_new_user_message(message: dict) -> list[dict]:
+     if not message["files"]:
+         return [{"type": "text", "text": message["text"]}]
+
+     file_types = [get_file_type(path) for path in message["files"]]
+
+     if len(file_types) == 1 and file_types[0] == "video":
+         gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
+
+         temp_dir = extract_frames_to_tempdir(
+             message["files"][0],
+             target_fps=TARGET_FPS,
+             max_frames=MAX_FRAMES,
+         )
+         paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
+         return [
+             {"type": "text", "text": message["text"]},
+             *[{"type": "image", "image": path.as_posix()} for path in paths],
+         ]
+
+     return [
+         {"type": "text", "text": message["text"]},
+         *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
+     ]
+
+
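+ # Rebuild the chat-template message list from the Gradio history; consecutive
+ # user items (text and file paths) are merged into a single user turn.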
+ def process_history(history: list[dict]) -> list[dict]:
      messages = []
+     current_user_content: list[dict] = []
+     for item in history:
+         if item["role"] == "assistant":
+             if current_user_content:
+                 messages.append({"role": "user", "content": current_user_content})
+                 current_user_content = []
+             messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
+         else:
+             content = item["content"]
+             if isinstance(content, str):
+                 current_user_content.append({"type": "text", "text": content})
+             else:
+                 filepath = content[0]
+                 file_type = get_file_type(filepath)
+                 current_user_content.append({"type": file_type, file_type: filepath})
      return messages


  @spaces.GPU()
+ @torch.inference_mode()
+ def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
+     if not validate_media_constraints(message):
+         yield ""
+         return

+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+     messages.extend(process_history(history))
+     messages.append({"role": "user", "content": process_new_user_message(message)})
+
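+     # Tokenize the conversation with the model's chat template and enforce a
+     # hard input-token cap before moving anything to the GPU.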
+     inputs = processor.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+     )
+     n_tokens = inputs["input_ids"].shape[1]
+     if n_tokens > MAX_INPUT_TOKENS:
+         gr.Warning(
+             f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. This limit is set to avoid CUDA out-of-memory errors in this Space."
+         )
+         yield ""
+         return
+
+     inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
+
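+     # Stream the reply: model.generate runs on a background thread while the
+     # streamer yields decoded text, so partial output reaches the UI immediately.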
+     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=False,
+         disable_compile=True,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()

+     output = ""
+     for delta in streamer:
+         output += delta
+         yield output


+ # Examples for the chat interface (with additional inputs: system_prompt, max_new_tokens)
+ examples = [
+     ["انت موديل عراقي تحكي هعراقي فقط وتكون ترفيهي", 700]
+ ]

+ # Create the chat interface
  demo = gr.ChatInterface(
+     fn=generate,
      type="messages",
+     textbox=gr.MultimodalTextbox(
+         file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
+         file_count="multiple",
+         autofocus=True,
      ),
+     multimodal=True,
+     additional_inputs=[
+         gr.Textbox(label="System Prompt", value="انت ذكاء صناعي يتحدث باللهجة العراقية بس ما تستخدم فصحى ابدا"),
+         gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
+     ],
+     title="Shako IRAQI AI",
+     examples=examples,
+     stop_btn=False,
  )

  if __name__ == "__main__":