|
|
import os

import numpy as np
import torch
import gradio as gr
import spaces
import warnings

warnings.filterwarnings("ignore")

from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
from qwen_omni_utils import process_mm_info
|
|
def _patched_mark_tied_weights_as_initialized(self):
    """Work around the lm_head tied-weights issue by skipping the check."""
    return
|
|
def _patched_init_weights(self, module):
    """
    Work around the missing initializer_range on Qwen3OmniMoeTalkerConfig.
    """
    try:
        std = self.config.initializer_range
    except AttributeError:
        std = 0.02

    if isinstance(module, torch.nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, torch.nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
|
|
def _patched_initialize_weights(self):
    """
    Complete override of initialize_weights in case the problems persist.
    """
    try:
        if hasattr(self, '_original_initialize_weights'):
            self._original_initialize_weights()
        else:
            # Fall back to a plain normal(0, 0.02) init for the common layers.
            for module in self.modules():
                if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
                    if hasattr(module, 'weight'):
                        module.weight.data.normal_(mean=0.0, std=0.02)
                    if hasattr(module, 'bias') and module.bias is not None:
                        module.bias.data.zero_()
    except Exception as e:
        print(f"Warning: Could not initialize weights properly: {e}")
|
|
def apply_patches():
    """Apply all required patches."""
    if hasattr(Qwen3OmniMoeForConditionalGeneration, "mark_tied_weights_as_initialized"):
        Qwen3OmniMoeForConditionalGeneration.mark_tied_weights_as_initialized = (
            _patched_mark_tied_weights_as_initialized
        )

    if hasattr(Qwen3OmniMoeForConditionalGeneration, "_init_weights"):
        Qwen3OmniMoeForConditionalGeneration._init_weights = _patched_init_weights

    if hasattr(Qwen3OmniMoeForConditionalGeneration, "initialize_weights"):
        # Keep a handle to the original so the patched version can delegate.
        Qwen3OmniMoeForConditionalGeneration._original_initialize_weights = (
            Qwen3OmniMoeForConditionalGeneration.initialize_weights
        )
        Qwen3OmniMoeForConditionalGeneration.initialize_weights = _patched_initialize_weights
|
|
apply_patches()
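# A minimal sanity check (an illustrative addition, not part of the original
# flow): confirm the _init_weights override took effect on this transformers
# version before any loading happens.
if getattr(Qwen3OmniMoeForConditionalGeneration, "_init_weights", None) is _patched_init_weights:
    print("[patch] _init_weights override is active")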
|
|
MODEL_PATH = os.getenv("MODEL_PATH", "Qwen/Qwen3-Omni-30B-A3B-Instruct")
USE_AUDIO_IN_VIDEO = True

VOICE_CHOICES = ["Ethan", "Chelsie", "Aiden"]
DEFAULT_VOICE = "Ethan"

model = None
processor = None
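# model and processor above are module-level caches: load_model() populates
# them on the first request, so the Space does not pay the load cost at
# import time.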
|
|
def load_model():
    """
    Load Qwen3-Omni and its processor on the first call only.

    - attn_implementation="eager" avoids the need for flash-attn.
    - low_cpu_mem_usage=True improves loading performance.
    - ignore_mismatched_sizes=True works around size-mismatch issues.
    """
    global model, processor

    if model is not None and processor is not None:
        return

    print(f"[ZeroGPU] Loading model from: {MODEL_PATH}")

    if torch.cuda.is_available():
        torch_dtype = torch.bfloat16
        device = "cuda"
    else:
        torch_dtype = torch.float32
        device = "cpu"

    try:
        local_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch_dtype,
            attn_implementation="eager",
            low_cpu_mem_usage=True,
            ignore_mismatched_sizes=True,
            trust_remote_code=True,
        )
        local_model = local_model.to(device)
        local_model.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Attempting alternative loading method...")

        try:
            # Temporarily disable _init_weights entirely for the retry.
            original_init_weights = None
            if hasattr(Qwen3OmniMoeForConditionalGeneration, "_init_weights"):
                original_init_weights = Qwen3OmniMoeForConditionalGeneration._init_weights
                Qwen3OmniMoeForConditionalGeneration._init_weights = lambda self, module: None

            local_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
                MODEL_PATH,
                torch_dtype=torch_dtype,
                attn_implementation="eager",
                low_cpu_mem_usage=True,
            )

            # Restore the original hook once loading has succeeded.
            if original_init_weights:
                Qwen3OmniMoeForConditionalGeneration._init_weights = original_init_weights

            local_model = local_model.to(device)
            local_model.eval()

        except Exception as e2:
            raise RuntimeError(f"Failed to load model: {e2}") from e2
    try:
        local_processor = Qwen3OmniMoeProcessor.from_pretrained(
            MODEL_PATH,
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading processor: {e}")
        raise

    model = local_model
    processor = local_processor
    print(f"[ZeroGPU] Model loaded successfully on {device} with dtype {torch_dtype}.")
|
|
def build_messages_from_history(
    history,
    system_prompt,
    user_text,
    image,
    audio_path,
    video_path,
):
    """
    Convert the chat history plus the current input into a conversation in
    the format expected by Qwen3-Omni.

    history: list of [user_text, assistant_text] pairs.
    """
    messages = []

    if system_prompt:
        messages.append(
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            }
        )

    # Replay previous turns as plain-text messages.
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append(
                {
                    "role": "user",
                    "content": [{"type": "text", "text": user_msg}],
                }
            )
        if assistant_msg:
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": assistant_msg}],
                }
            )

    # Assemble the current multimodal user turn.
    user_content = []

    if image is not None:
        user_content.append({"type": "image", "image": image})

    if audio_path:
        user_content.append({"type": "audio", "audio": audio_path})

    if video_path:
        user_content.append({"type": "video", "video": video_path})

    if user_text and user_text.strip():
        user_content.append({"type": "text", "text": user_text.strip()})

    if user_content:
        messages.append(
            {
                "role": "user",
                "content": user_content,
            }
        )

    return messages
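# Illustrative shape of the conversation the helper above returns, e.g. for
# build_messages_from_history([], "Be brief.", "Hi", None, None, None):
#   [
#       {"role": "system", "content": [{"type": "text", "text": "Be brief."}]},
#       {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
#   ]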
|
|
@spaces.GPU(duration=120)
def qwen3_omni_inference(
    history,
    user_text,
    image,
    audio_path,
    video_path,
    system_prompt,
    return_audio,
    speaker,
    temperature,
    top_p,
    max_tokens,
):
    """
    Run inference on ZeroGPU.

    - Supports text + image + audio + video in the same message.
    - Always produces a text reply; audio output is optional.
    """
    # Nothing to do if the user supplied no text and no media.
    if not (user_text or image is not None or audio_path or video_path):
        return history, None, "", None, None, None

    try:
        load_model()
        global model, processor

        messages = build_messages_from_history(
            history=history,
            system_prompt=system_prompt,
            user_text=user_text,
            image=image,
            audio_path=audio_path,
            video_path=video_path,
        )

        # Render the conversation to a prompt string, with the generation
        # marker appended for the assistant turn.
        text_prompt = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        # Extract the raw audio/image/video payloads referenced in messages.
        audios, images, videos = process_mm_info(
            messages,
            use_audio_in_video=USE_AUDIO_IN_VIDEO,
        )

        inputs = processor(
            text=text_prompt,
            audio=audios,
            images=images,
            videos=videos,
            return_tensors="pt",
            padding=True,
            use_audio_in_video=USE_AUDIO_IN_VIDEO,
        )

        # Move every tensor in the batch onto the model's device.
        device = next(model.parameters()).device
        for key in inputs:
            if hasattr(inputs[key], 'to'):
                inputs[key] = inputs[key].to(device)

        gen_kwargs = dict(
            # temperature == 0 requests greedy decoding; the epsilon keeps the
            # value valid while do_sample=False does the actual work.
            temperature=float(temperature) if temperature > 0 else 1e-7,
            top_p=float(top_p),
            max_new_tokens=int(max_tokens),
            do_sample=temperature > 0,
            use_audio_in_video=USE_AUDIO_IN_VIDEO,
        )

        if hasattr(model, 'config') and hasattr(model.config, 'thinker_return_dict_in_generate'):
            gen_kwargs["thinker_return_dict_in_generate"] = True

        with torch.no_grad():
            if not return_audio:
                gen_kwargs["return_audio"] = False
                outputs = model.generate(**inputs, **gen_kwargs)
                text_ids = outputs
                audio_out = None
            else:
                gen_kwargs["speaker"] = speaker
                gen_kwargs["return_audio"] = True
                outputs = model.generate(**inputs, **gen_kwargs)
                # generate() may return (text_ids, audio) or text_ids alone.
                if isinstance(outputs, tuple):
                    text_ids, audio_out = outputs
                else:
                    text_ids = outputs
                    audio_out = None

        # Keep only the newly generated tokens, then decode them.
        input_len = inputs["input_ids"].shape[1]
        if hasattr(text_ids, 'sequences'):
            generated_ids = text_ids.sequences[:, input_len:]
        else:
            generated_ids = text_ids[:, input_len:]

        generated_text = processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        user_display = (
            user_text if (user_text and user_text.strip()) else "[Multimodal message]"
        )
        history = history + [[user_display, generated_text]]

        # Convert the generated waveform into the (rate, ndarray) pair that
        # gr.Audio(type="numpy") expects.
        gr_audio = None
        if audio_out is not None:
            try:
                audio_np = audio_out.reshape(-1).detach().cpu().numpy()
                sample_rate = 24000
                gr_audio = (sample_rate, audio_np.astype(np.float32))
            except Exception as e:
                print(f"Warning: Could not process audio output: {e}")
                gr_audio = None

        return history, gr_audio, "", None, None, None
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()

        user_display = (
            user_text if (user_text and user_text.strip()) else "[Multimodal message]"
        )
        error_message = f"Sorry, an error occurred while processing the message: {str(e)}"
        history = history + [[user_display, error_message]]
        return history, None, "", None, None, None
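# The 6-tuple returned above lines up with send_outputs in create_interface():
# (updated history, audio reply, cleared textbox, cleared image, cleared
# audio, cleared video).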
|
|
def clear_chat():
    """Reset the conversation and the audio output."""
    return [], None
|
|
def create_interface():
    with gr.Blocks(
        title="Qwen3-Omni-30B-A3B – ZeroGPU Chat",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown(
            """
            <h1 style="text-align:center;">🤖 Qwen3-Omni-30B-A3B – ZeroGPU Chat</h1>
            <p style="text-align:center;">
            Multimodal chat (text + image + audio + video) running on ZeroGPU.<br/>
            Type a message, optionally attach an image/audio/video, then press <b>Send</b> or hit Enter.<br/>
            (Use Shift+Enter for a new line.)
            </p>
            """
        )

        with gr.Row():
            # Left column: the conversation plus all input widgets.
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=480,
                    elem_id="chatbot",
                )

                audio_output = gr.Audio(
                    label="Model reply (audio)",
                    type="numpy",
                    autoplay=True,
                    visible=True,
                )

                with gr.Row():
                    user_text = gr.Textbox(
                        label="Your message",
                        placeholder="Type your message here (you can also attach an image/audio/video below)...",
                        lines=3,
                        show_label=False,
                        elem_id="message",
                    )

                with gr.Row():
                    image_input = gr.Image(
                        label="📷 Image (optional)",
                        type="pil",
                        sources=["upload", "webcam"],
                        height=150,
                    )
                    audio_input = gr.Audio(
                        label="🎙️ Audio (optional)",
                        type="filepath",
                        sources=["microphone", "upload"],
                    )
                    video_input = gr.Video(
                        label="🎬 Video (optional)",
                        height=150,
                    )

                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary", scale=2)
                    clear_btn = gr.Button("Clear chat", variant="secondary")

            # Right column: model settings.
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Model settings")

                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value="You are a helpful, multilingual assistant.",
                    lines=4,
                    placeholder="You can control the model's persona from here (optional).",
                )

                return_audio = gr.Checkbox(
                    label="Enable audio output (the model speaks)?",
                    value=True,
                )

                speaker = gr.Dropdown(
                    label="Speaker voice",
                    choices=VOICE_CHOICES,
                    value=DEFAULT_VOICE,
                )

                with gr.Accordion("Advanced settings", open=False):
                    temperature = gr.Slider(
                        label="Temperature (randomness)",
                        minimum=0.0,
                        maximum=1.5,
                        value=0.6,
                        step=0.05,
                    )

                    top_p = gr.Slider(
                        label="Top-p (nucleus sampling)",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                    )

                    max_tokens = gr.Slider(
                        label="Max new tokens (maximum reply length)",
                        minimum=16,
                        maximum=1024,
                        value=384,
                        step=16,
                    )

                gr.Markdown(
                    """
                    **📝 Notes:**
                    - You can send text only, or text with an image/audio/video in a single message
                    - Enter sends, Shift+Enter adds a new line
                    - Running on ZeroGPU may take several seconds depending on message length
                    - The model supports many languages, including Arabic and English
                    """
                )
|
|
        history_state = gr.State([])

        # Inputs/outputs shared by the click and submit handlers; the order
        # must match the qwen3_omni_inference signature and return tuple.
        send_inputs = [
            history_state,
            user_text,
            image_input,
            audio_input,
            video_input,
            system_prompt,
            return_audio,
            speaker,
            temperature,
            top_p,
            max_tokens,
        ]

        send_outputs = [
            history_state,
            audio_output,
            user_text,
            image_input,
            audio_input,
            video_input,
        ]
|
|
        # Run inference, then mirror the updated history into the Chatbot.
        send_btn.click(
            fn=qwen3_omni_inference,
            inputs=send_inputs,
            outputs=send_outputs,
            queue=True,
        ).then(
            lambda h: h,
            inputs=history_state,
            outputs=chatbot,
        )

        user_text.submit(
            fn=qwen3_omni_inference,
            inputs=send_inputs,
            outputs=send_outputs,
            queue=True,
        ).then(
            lambda h: h,
            inputs=history_state,
            outputs=chatbot,
        )
|
|
        clear_btn.click(
            fn=clear_chat,
            inputs=None,
            outputs=[history_state, audio_output],
        ).then(
            lambda: [],
            inputs=None,
            outputs=chatbot,
        ).then(
            lambda: ("", None, None, None),
            inputs=None,
            outputs=[user_text, image_input, audio_input, video_input],
        )

        demo.load(
            lambda: gr.Info("Loading the model... this may take a few minutes the first time."),
            inputs=None,
            outputs=None,
        )
|
|
    return demo


demo = create_interface()

if __name__ == "__main__":
    demo.launch(
        ssr_mode=False,
        show_error=True,
    )