Update app2.py
Browse files
app2.py
CHANGED
|
@@ -1,179 +1,228 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
|
| 3 |
import os
|
| 4 |
-
import
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
#
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
messages = []
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
return messages
|
| 74 |
|
| 75 |
-
def detect_language(text):
    """Heuristically detect whether *text* is Arabic ('ar') or English ('en').

    Counts characters in the Arabic Unicode block (U+0600..U+06FF) and
    returns 'ar' when they make up more than 30% of the non-space text.
    Empty or all-space input defaults to Arabic.
    """
    stripped = text.replace(' ', '')
    if not stripped:
        return 'ar'  # default to Arabic
    arabic_count = len([ch for ch in stripped if '\u0600' <= ch <= '\u06FF'])
    return 'ar' if arabic_count / len(stripped) > 0.3 else 'en'
|
| 85 |
|
| 86 |
@spaces.GPU()
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
# Build conversation for Jais format
|
| 93 |
-
conversation_parts = []
|
| 94 |
-
|
| 95 |
-
# Add system prompt as part of the instruction (keep it short for Jais)
|
| 96 |
-
system_instruction = "اسمك \"أليكس\" وأنت مساعد خدمة العملاء في شركة TechSolutions. مهمتك مساعدة العملاء في حل مشاكلهم مع المنتجات والإجابة عن أسئلتهم حول الخدمات. كن ودوداً وصبوراً ومحترماً. أجب بالعربية أو الإنجليزية حسب تفضيل العميل. ابدأ بالتحية وكن مباشراً في الحلول."
|
| 97 |
-
|
| 98 |
-
# Add chat history
|
| 99 |
-
if chat_history:
|
| 100 |
-
for item in chat_history:
|
| 101 |
-
role = item["role"]
|
| 102 |
-
content = item["content"]
|
| 103 |
-
if isinstance(content, list):
|
| 104 |
-
content = content[0]["text"] if content and "text" in content[0] else str(content)
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
yield response
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
yield "أهلاً! أنا أليكس مساعد خدمة العملاء. كيف أقدر أساعدك اليوم؟"
|
| 139 |
|
|
|
|
| 140 |
demo = gr.ChatInterface(
|
| 141 |
-
fn=
|
| 142 |
-
additional_inputs=[
|
| 143 |
-
gr.Slider(label="الحد الأقصى للكلمات الجديدة", minimum=64, maximum=4096, step=1, value=2048),
|
| 144 |
-
gr.Slider(label="درجة الحرارة", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
|
| 145 |
-
gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
|
| 146 |
-
gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
|
| 147 |
-
gr.Slider(label="عقوبة التكرار", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
|
| 148 |
-
],
|
| 149 |
-
examples=[
|
| 150 |
-
[{"text": "النت عندي معطل من الصبح، تقدر تساعدني؟"}],
|
| 151 |
-
[{"text": "عندي مشكلة بالاتصال بالواي فاي"}],
|
| 152 |
-
[{"text": "شنو الباقات المتوفرة عندكم؟"}],
|
| 153 |
-
[{"text": "كيف أعيد ضبط الجهاز؟"}],
|
| 154 |
-
[{"text": "My device is not working properly"}],
|
| 155 |
-
],
|
| 156 |
-
cache_examples=False,
|
| 157 |
type="messages",
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
- 🌐 دعم ثنائي اللغة (عربي وإنجليزي)
|
| 163 |
-
- 💬 لهجة محادثة طبيعية
|
| 164 |
-
- 🔧 دعم فني واستكشاف الأخطاء
|
| 165 |
-
- 📋 معلومات الخدمات والإرشاد
|
| 166 |
-
- 🎯 مدعوم بـ موديل Unsloth Meta-Llama-3.1-8B-Instruct (مع تحسينات الأداء)
|
| 167 |
-
|
| 168 |
-
احجي مع أليكس لحل مشاكلك التقنية، استفسر عن الخدمات، أو احصل على معلومات المنتجات.""",
|
| 169 |
-
fill_height=True,
|
| 170 |
-
textbox=gr.Textbox(
|
| 171 |
-
label="اكتب رسالتك هنا",
|
| 172 |
-
placeholder="مثال: عندي مشكلة بالجهاز..."
|
| 173 |
),
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
)
|
| 178 |
|
| 179 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
import tempfile
|
| 4 |
+
from collections.abc import Iterator
|
| 5 |
+
from threading import Thread
|
| 6 |
+
|
| 7 |
+
import av
|
| 8 |
import gradio as gr
|
| 9 |
import spaces
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import AutoModelForImageTextToText, AutoProcessor
|
| 12 |
+
from transformers.generation.streamers import TextIteratorStreamer
|
| 13 |
+
|
| 14 |
+
# Model configuration: weights are downloaded and loaded once at import time.
model_id = "anaspro/Shako-4B-it-v2"
processor = AutoProcessor.from_pretrained(model_id)
# device_map="auto" lets accelerate place shards on available devices.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Supported file types (extension allow-lists used by get_file_type and the
# MultimodalTextbox below).
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
AUDIO_FILE_TYPES = (".mp3", ".wav")

# Video processing settings, overridable via environment variables.
TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))  # frame sampling rate
MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))  # cap on extracted frames
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))  # prompt-length guard against OOM
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_file_type(path: str) -> str:
    """Classify *path* as "image", "video", or "audio" by its extension.

    Raises:
        ValueError: If the extension is not one of the supported types.
    """
    for kind, suffixes in (
        ("image", IMAGE_FILE_TYPES),
        ("video", VIDEO_FILE_TYPES),
        ("audio", AUDIO_FILE_TYPES),
    ):
        if path.endswith(suffixes):
            return kind
    raise ValueError(f"Unsupported file type: {path}")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Split *paths* into (video_count, non_video_count) by file extension."""
    video_count = sum(1 for path in paths if path.endswith(VIDEO_FILE_TYPES))
    return video_count, len(paths) - video_count
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def validate_media_constraints(message: dict) -> bool:
    """Check the attached files against this Space's media limits.

    Shows a Gradio warning and returns False when more than one video is
    attached, or when a video is mixed with other files; otherwise True.
    """
    n_videos, n_other = count_files_in_new_message(message["files"])
    warning = None
    if n_videos > 1:
        warning = "Only one video is supported."
    elif n_videos == 1 and n_other > 0:
        warning = "Mixing images and videos is not allowed."
    if warning is not None:
        gr.Warning(warning)
        return False
    return True
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def extract_frames_to_tempdir(
    video_path: str,
    target_fps: float,
    max_frames: int | None = None,
    parent_dir: str | None = None,
    prefix: str = "frames_",
) -> str:
    """Sample frames from *video_path* at roughly *target_fps* into a temp dir.

    Frames are written as ``frame_0000.jpg``, ``frame_0001.jpg``, ... into a
    newly created temporary directory; the caller owns its cleanup.

    Args:
        video_path: Path to a video file readable by PyAV.
        target_fps: Desired sampling rate in frames per second.
        max_frames: Optional hard cap on the number of extracted frames.
        parent_dir: Optional parent for the temporary directory.
        prefix: Prefix for the temporary directory name.

    Returns:
        Path of the temporary directory containing the extracted frames.

    Raises:
        ValueError: If the video stream lacks duration/time_base metadata.
    """
    temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)

    container = av.open(video_path)
    # Fix: the original never closed the container when the metadata check
    # raised or decoding failed, leaking the demuxer handle.
    try:
        video_stream = container.streams.video[0]

        if video_stream.duration is None or video_stream.time_base is None:
            raise ValueError("video_stream is missing duration or time_base")

        time_base = video_stream.time_base
        duration = float(video_stream.duration * time_base)
        interval = 1.0 / target_fps

        total_frames = int(duration * target_fps)
        if max_frames is not None:
            total_frames = min(total_frames, max_frames)

        target_times = [i * interval for i in range(total_frames)]
        target_index = 0

        for frame in container.decode(video=0):
            if frame.pts is None:
                continue

            timestamp = float(frame.pts * time_base)

            # Save a frame when it lands within half an interval of the next
            # target timestamp; this approximates uniform target_fps sampling.
            if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
                frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
                frame.to_image().save(frame_path)
                target_index += 1

            if max_frames is not None and target_index >= max_frames:
                break
    finally:
        container.close()
    return temp_dir
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def process_new_user_message(message: dict) -> list[dict]:
    """Turn a MultimodalTextbox payload into chat-template content parts.

    The text part always comes first; attached files follow as typed parts.
    A single video attachment is expanded into sampled image frames.
    """
    files = message["files"]
    parts: list[dict] = [{"type": "text", "text": message["text"]}]
    if not files:
        return parts

    file_types = [get_file_type(path) for path in files]

    if file_types == ["video"]:
        gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")

        frames_dir = extract_frames_to_tempdir(
            files[0],
            target_fps=TARGET_FPS,
            max_frames=MAX_FRAMES,
        )
        frame_paths = sorted(pathlib.Path(frames_dir).glob("*.jpg"))
        parts.extend({"type": "image", "image": p.as_posix()} for p in frame_paths)
        return parts

    parts.extend(
        {"type": file_type, file_type: path}
        for path, file_type in zip(files, file_types, strict=True)
    )
    return parts
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def process_history(history: list[dict]) -> list[dict]:
    """Convert Gradio ``type="messages"`` history into chat-template messages.

    Consecutive user items (text parts and file parts) are merged into a
    single user message; an assistant item flushes any pending user content
    before being appended.

    Args:
        history: Items with "role" and "content" keys; a user file turn
            carries a tuple/list whose first element is the file path.

    Returns:
        Messages shaped as ``{"role": ..., "content": [parts...]}``.
    """
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                filepath = content[0]
                file_type = get_file_type(filepath)
                current_user_content.append({"type": file_type, file_type: filepath})
    # Fix: flush a trailing user turn so pending messages are not silently
    # dropped when history does not end with an assistant reply (e.g. a
    # previous generation was interrupted).
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
@spaces.GPU()
@torch.inference_mode()
def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
    """Stream a model reply for the latest multimodal message.

    Args:
        message: MultimodalTextbox payload with "text" and "files" keys.
        history: Prior chat turns in Gradio "messages" format.
        system_prompt: Optional system instruction prepended to the chat.
        max_new_tokens: Cap on generated tokens.

    Yields:
        The partial response text, growing as tokens stream in; a single
        empty string when validation or the token-length guard rejects
        the request (a Gradio warning has already been shown).
    """
    # Reject unsupported media combinations up front.
    if not validate_media_constraints(message):
        yield ""
        return

    # Assemble the full conversation: system prompt, prior turns, new message.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.extend(process_history(history))
    messages.append({"role": "user", "content": process_new_user_message(message)})

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Refuse overly long prompts to avoid CUDA out-of-memory in this Space.
    n_tokens = inputs["input_ids"].shape[1]
    if n_tokens > MAX_INPUT_TOKENS:
        gr.Warning(
            f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. This limit is set to avoid CUDA out-of-memory errors in this Space."
        )
        yield ""
        return

    # NOTE(review): BatchFeature.to is expected to cast only floating-point
    # tensors, leaving input_ids integral — confirm against the installed
    # transformers version.
    inputs = inputs.to(device=model.device, dtype=torch.bfloat16)

    # Run generation on a worker thread; the streamer yields decoded text
    # chunks (timeout raises if no token arrives within 30 s).
    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        disable_compile=True,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate deltas and re-yield the whole response so Gradio replaces
    # the message in place.
    output = ""
    for delta in streamer:
        output += delta
        yield output
|
| 202 |
|
|
|
|
| 203 |
|
| 204 |
+
# Examples for the chat interface. NOTE(review): with additional inputs
# (system_prompt, max_new_tokens) an example row is expected to be
# [message, system_prompt, max_new_tokens]; this row has only two elements,
# so 700 would bind to system_prompt — verify against the Gradio version.
examples = [
    ["انت موديل عراقي تحكي هعراقي فقط وتكون ترفيهي", 700]
]

# Create the chat interface wired to the streaming generate() function.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    # Multimodal input restricted to the extensions the model pipeline accepts.
    textbox=gr.MultimodalTextbox(
        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
        file_count="multiple",
        autofocus=True,
    ),
    multimodal=True,
    # Extra controls exposed under the chat box; order must match generate()'s
    # system_prompt and max_new_tokens parameters.
    additional_inputs=[
        gr.Textbox(label="System Prompt", value="انت ذكاء صناعي يتحدث باللهجة العراقية بس ما تستخدم فصحى ابدا"),
        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
    ],
    title="Shako IRAQI AI",
    examples=examples,
    stop_btn=False,
)
|
| 227 |
|
| 228 |
if __name__ == "__main__":
|