import gradio as gr
import requests
from PIL import Image
import torch
from transformers import AutoModel, AutoProcessor
import spaces

# Hugging Face Hub id of the R-4B multimodal chat model.
model_path = "YannQi/R-4B"

# Load the model with its custom (remote) code and move it to the GPU.
# NOTE(review): float32 on CUDA doubles memory vs. float16/bfloat16 — kept
# as-is to preserve behavior; confirm whether half precision is acceptable.
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).to("cuda")

# Processor handles both the chat template and image preprocessing.
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
| |
|
# Extensions accepted as images for the current turn; mirrors the
# MultimodalTextbox file_types filter in the UI.
_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".gif", ".bmp")


def _file_path(f):
    """Return a filesystem path for one gradio file entry.

    Gradio returns either ``{"path": ...}`` dicts or bare path strings
    depending on version; the original ``f.get('path', str(f))`` raised
    AttributeError on plain strings.
    """
    if isinstance(f, dict):
        return f.get("path", str(f))
    return str(f)


def _collect_user_content(text, files, images, check_ext):
    """Build the chat-template content list for one user turn.

    Opens each readable image file, appending the PIL image to *images*
    (mutated in place) and an image entry to the returned content list.
    Unreadable files are skipped best-effort, matching the original
    behavior. When *check_ext* is True, files without a known image
    extension are ignored.
    """
    content = []
    for f in files:
        path = _file_path(f)
        if check_ext and not path.lower().endswith(_IMAGE_EXTS):
            continue
        try:
            img = Image.open(path)
        except (OSError, ValueError):
            # Not a readable image file — skip silently, as before, but
            # without the original bare `except:` that hid every error.
            continue
        images.append(img)
        content.append({"type": "image", "image": path})
    if text:
        content.append({"type": "text", "text": text})
    return content


@spaces.GPU(duration=120)
def generate_response(message, history, thinking_mode):
    """Run one chat turn through the R-4B model.

    Parameters
    ----------
    message : str | dict
        Current user input: a MultimodalTextbox dict with "text"/"files"
        keys, or a plain string.
    history : list[tuple]
        Chatbot history as (user, assistant) pairs.
    thinking_mode : str
        "auto", "long", or "short"; forwarded to the chat template.

    Returns
    -------
    tuple[str, list]
        An empty string (clears the input box) and the updated history.
    """
    if not message:
        return "", history

    messages = []
    all_images = []

    # Replay prior turns so the model sees full conversational context.
    for user_msg, asst_msg in history:
        if isinstance(user_msg, str):
            user_content = [{"type": "text", "text": user_msg}]
        else:
            # History files already passed the upload filter once, so no
            # extension check here (preserves the original asymmetry).
            user_content = _collect_user_content(
                user_msg.get("text", ""),
                user_msg.get("files", []),
                all_images,
                check_ext=False,
            )
        messages.append({"role": "user", "content": user_content})

        asst_text = asst_msg if isinstance(asst_msg, str) else asst_msg.get("text", "")
        messages.append(
            {"role": "assistant", "content": [{"type": "text", "text": asst_text}]}
        )

    # Normalize the current message into text + files.
    if isinstance(message, str):
        curr_text, curr_files = message, []
    else:
        curr_text = message.get("text", "")
        curr_files = message.get("files", [])

    curr_images = []
    curr_user_content = _collect_user_content(
        curr_text, curr_files, curr_images, check_ext=True
    )
    if not curr_user_content:
        # Nothing usable (no text, no readable images): leave history alone.
        return "", history
    messages.append({"role": "user", "content": curr_user_content})

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        thinking_mode=thinking_mode,
    )

    # Image order must match the order of image placeholders in the prompt:
    # history images first, then the current turn's.
    all_images += curr_images

    inputs = processor(
        images=all_images if all_images else None,
        text=text,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
        )
    # Strip the prompt tokens; decode only the newly generated tail.
    output_ids = generated_ids[0][len(inputs.input_ids[0]):]
    output_text = processor.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    new_history = history + [(message, output_text)]
    return "", new_history
| |
|
with gr.Blocks(title="Transformers Chat") as demo:
    gr.Markdown("# Using 🤗 Transformers to Chat")
    gr.Markdown("Select thinking mode: auto (auto-thinking), long (thinking), short (non-thinking). Default: auto.")
    # (user, assistant) tuple history matches what generate_response
    # consumes and returns. NOTE(review): type="tuples" is deprecated in
    # newer gradio releases — confirm before upgrading gradio.
    chatbot = gr.Chatbot(type="tuples", height=500, label="Chat")
    with gr.Row():
        # Text + image input; the extension list mirrors the image-file
        # check inside generate_response.
        msg = gr.MultimodalTextbox(
            placeholder="Type your message or upload images...",
            file_types=[".jpg", ".jpeg", ".png", ".gif", ".bmp"],
            file_count="multiple",
            label="Message"
        )
        mode = gr.Dropdown(
            choices=["auto", "long", "short"],
            value="auto",
            label="Thinking Mode",
            interactive=True
        )
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary", scale=3)
        clear_btn = gr.Button("Clear", scale=1)
    # Both the Send button and Enter in the textbox run the same handler;
    # its outputs clear the textbox and replace the chat history.
    submit_btn.click(generate_response, [msg, chatbot, mode], [msg, chatbot])
    msg.submit(generate_response, [msg, chatbot, mode], [msg, chatbot])
    # Resets history to [] and the textbox to "". NOTE(review): newer
    # gradio expects None or {"text": "", "files": []} for a
    # MultimodalTextbox value — verify "" is accepted.
    clear_btn.click(lambda: ([], ""), None, [chatbot, msg], queue=False)


if __name__ == "__main__":
    demo.launch()