# Hugging Face Spaces page header (Space status when captured: Paused)
| import os | |
| import time | |
| import random | |
| import gradio as gui | |
| from gradio.themes.utils import colors | |
| from dataclasses import dataclass | |
| from typing import Dict, Iterator, List, Literal, Optional, TypedDict, NotRequired | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
| import torch | |
# Custom theme for the Gradio interface: the Default theme with a
# blue/green/gray palette, the Inter font stack, and explicit colour overrides.
_theme_fonts = [
    gui.themes.GoogleFont("Inter"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]

# Theme variables passed to ``.set()`` on top of the palette chosen above.
_theme_overrides = {
    "body_background_fill": "#FFFFFF",
    "body_text_color": "#1F2937",
    "button_primary_background_fill": "#2D7FF9",
    "button_primary_background_fill_hover": "#1A56F0",
    "button_secondary_background_fill": "#10B981",
    "button_secondary_background_fill_hover": "#059669",
    "block_title_text_color": "#6B7280",
    "block_label_text_color": "#6B7280",
    "background_fill_primary": "#F9FAFB",
    "background_fill_secondary": "#F3F4F6",
}

custom_theme = gui.themes.Default(
    primary_hue=colors.blue,
    secondary_hue=colors.green,
    neutral_hue=colors.gray,
    font=_theme_fonts,
).set(**_theme_overrides)
@dataclass
class UserMessage:
    """One chat message, from either the user or the assistant.

    The ``@dataclass`` decorator generates the keyword constructor that the
    rest of this module relies on (``UserMessage(role=..., content=...)``);
    without it the class only carries annotations and cannot be instantiated
    with arguments.
    """

    # Message text.
    content: str
    # Which side of the conversation produced the message.
    role: Literal["user", "assistant"]
    # Optional display metadata; this file sets "title", "status" and "duration".
    metadata: Optional[Dict] = None
    # Optional list of option dicts; not populated anywhere in this file.
    options: Optional[List[Dict]] = None
| class Metadata(TypedDict): | |
| title: NotRequired[str] | |
| id: NotRequired[int | str] | |
| parent_id: NotRequired[int | str] | |
| log: NotRequired[str] | |
| duration: NotRequired[float] | |
| status: NotRequired[Literal["pending", "done"]] | |
# Hub repository id of the checkpoint to load.
# NOTE(review): the SmolLM2 checkpoints are published under the
# "HuggingFaceTB" organisation -- confirm that this repo id actually resolves.
MODEL_IDENTIFIER = "smol-ai/SmolLM2-135M-Instruct"


def load_model():
    """Load the causal language model and tokenizer for ``MODEL_IDENTIFIER``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_IDENTIFIER)
    # Half precision is only worthwhile on a GPU. On CPU-only hosts float16
    # inference is slow and, depending on the torch build, some ops are not
    # implemented for it, so fall back to float32 there.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_IDENTIFIER,
        torch_dtype=dtype,
        device_map="auto",
    )
    return model, tokenizer
# Load the weights once at import time; stream_chat_response reads these two
# module-level names on every request.
print("Loading model and tokenizer...")
_loaded_pair = load_model()
model_instance, tokenizer_instance = _loaded_pair
print("Model and tokenizer loaded!")
def build_conversation_prompt(
    current_message: str, history: "List[UserMessage | dict]"
) -> str:
    """Render the conversation as a plain-text prompt for the model.

    Each prior message becomes a ``ROLE: content`` line, followed by the new
    user message and a trailing ``ASSISTANT: `` cue for the model to complete.

    Args:
        current_message: The text the user just submitted.
        history: Earlier messages. Entries may be attribute-style objects
            (``.role`` / ``.content``) or plain dicts with ``"role"`` and
            ``"content"`` keys, which is how a ``type="messages"`` Chatbot
            hands history to its callback. ``None`` is treated as empty.

    Returns:
        The newline-joined prompt string.

    Raises:
        KeyError / AttributeError: If an entry lacks a role or content.
    """

    def _field(message, name):
        # Dict entries come from Gradio; objects come from this module.
        if isinstance(message, dict):
            return message[name]
        return getattr(message, name)

    prompt_lines = [
        f"{_field(message, 'role').upper()}: {_field(message, 'content')}"
        for message in (history or [])
    ]
    prompt_lines.append(f"USER: {current_message}")
    prompt_lines.append("ASSISTANT: ")
    return "\n".join(prompt_lines)
# Canned phrases shown while the "thinking" bubble is animated.
_THOUGHT_SEGMENTS = (
    "Analyzing the user's query...",
    "Retrieving relevant information...",
    "Considering different perspectives...",
    "Formulating a coherent response...",
    "Checking for accuracy and completeness...",
    "Organizing thoughts in a logical structure...",
)


def _to_gradio_messages(messages) -> List[Dict]:
    """Convert history entries into plain dicts for the Chatbot component.

    A ``type="messages"`` Chatbot accepts dicts with ``role``/``content`` keys
    (or Gradio's own message classes), not arbitrary objects, so the
    ``UserMessage`` instances built in this module are flattened here. Dict
    entries (history handed in by Gradio) pass through untouched. ``None``
    metadata/options are omitted rather than sent as explicit nulls.
    """
    payload = []
    for message in messages:
        if isinstance(message, dict):
            payload.append(message)
            continue
        entry = {"role": message.role, "content": message.content}
        metadata = getattr(message, "metadata", None)
        if metadata is not None:
            entry["metadata"] = metadata
        options = getattr(message, "options", None)
        if options is not None:
            entry["options"] = options
        payload.append(entry)
    return payload


def stream_chat_response(user_input: str, history: List[UserMessage]) -> Iterator[List[Dict]]:
    """Stream the assistant's reply to ``user_input`` as successive histories.

    Generation runs on a background thread while this generator first shows
    an animated "thinking" bubble and then relays the streamed text.

    Args:
        user_input: Text submitted by the user.
        history: Conversation so far, as supplied by the Chatbot component.

    Yields:
        The full conversation as a list of message dicts, once per UI update.
    """
    history = list(history or [])
    if not user_input or not user_input.strip():
        # Nothing to answer: leave the conversation as it is instead of
        # sending a blank prompt to the model.
        yield _to_gradio_messages(history)
        return

    prompt_text = build_conversation_prompt(user_input, history)
    inputs = tokenizer_instance(prompt_text, return_tensors="pt").to(model_instance.device)
    response_streamer = TextIteratorStreamer(
        tokenizer_instance,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_params = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.9,
        "streamer": response_streamer,
        "do_sample": True,
    }
    # generate() blocks until done, so it runs on its own thread and feeds
    # the streamer while this generator keeps the UI updated.
    thread = Thread(target=model_instance.generate, kwargs=generation_params)
    thread.start()

    # Reference point for the duration reported on the finished thinking bubble.
    start_time = time.time()
    thought_buffer = ""
    updated_history = history + [UserMessage(role="user", content=user_input)]
    updated_history.append(create_thinking_message())
    yield _to_gradio_messages(updated_history)

    for _ in range(random.randint(3, 6)):
        thought_buffer = update_thoughts(thought_buffer, updated_history)
        yield _to_gradio_messages(updated_history)
        time.sleep(0.5)

    finalize_thinking(updated_history, thought_buffer, start_time)
    yield _to_gradio_messages(updated_history)

    # finalize_thinking appended an empty assistant message; fill it in as
    # text arrives from the streamer.
    answer_text = ""
    for text_chunk in response_streamer:
        answer_text += text_chunk
        updated_history[-1] = UserMessage(role="assistant", content=answer_text)
        yield _to_gradio_messages(updated_history)
        time.sleep(0.01)

    # The streamer is exhausted, so generation is finishing; reap the thread.
    thread.join()


def create_thinking_message() -> UserMessage:
    """Return the empty, pending "thinking" bubble shown before any thoughts."""
    return UserMessage(
        role="assistant",
        content="",
        metadata={
            "title": "🧠 Thinking...",
            "status": "pending",
        },
    )


def update_thoughts(thought_buffer: str, updated_history: List[UserMessage]) -> str:
    """Append one random canned thought and refresh the thinking bubble.

    Args:
        thought_buffer: Thoughts accumulated so far.
        updated_history: Conversation whose last entry is the thinking
            bubble; that entry is replaced in place.

    Returns:
        The extended thought buffer.
    """
    thought_buffer += random.choice(_THOUGHT_SEGMENTS) + " "
    updated_history[-1] = UserMessage(
        role="assistant",
        content=thought_buffer,
        metadata={
            "title": "🧠 Thinking...",
            "status": "pending",
        },
    )
    return thought_buffer


def finalize_thinking(
    updated_history: List[UserMessage],
    thought_buffer: str,
    start_time: Optional[float] = None,
):
    """Mark the thinking bubble as done and append an empty answer message.

    Args:
        updated_history: Conversation whose last entry is the thinking
            bubble; mutated in place.
        thought_buffer: Final text of the thinking bubble.
        start_time: ``time.time()`` value taken when thinking began. When
            omitted, no ``duration`` is recorded (the elapsed time is
            unknown).
    """
    metadata = {
        "title": "🧠 Thinking Process",
        "status": "done",
    }
    if start_time is not None:
        metadata["duration"] = round(time.time() - start_time, 2)
    updated_history[-1] = UserMessage(
        role="assistant",
        content=thought_buffer,
        metadata=metadata,
    )
    # Placeholder that stream_chat_response fills with the streamed answer.
    updated_history.append(UserMessage(role="assistant", content=""))
def reset_chat() -> List[UserMessage]:
    """Return a new empty history; wired to the "Clear Chat" button."""
    empty_history = list()
    return empty_history
# Extra CSS handed to gui.Blocks: message bubbles, the thinking box, the
# scrolling chat area, the sticky input row, and a rule hiding the footer.
style_sheet = """
.message-user {
    background-color: #F3F4F6 !important;
    border-radius: 10px;
    padding: 10px;
    margin: 8px 0;
}
.message-assistant {
    background-color: #F9FAFB !important;
    border-radius: 10px;
    padding: 10px;
    margin: 8px 0;
    border-left: 3px solid #2D7FF9;
}
.thinking-box {
    background-color: #F0F9FF !important;
    border: 1px solid #BAE6FD;
    border-radius: 6px;
}
.chat-container {
    height: calc(100vh - 230px);
    overflow-y: auto;
    padding: 16px;
}
.input-container {
    position: sticky;
    bottom: 0;
    background-color: #FFFFFF;
    padding: 16px;
    border-top: 1px solid #E5E7EB;
}
@media (max-width: 640px) {
    .chat-container {
        height: calc(100vh - 200px);
    }
}
footer {
    display: none !important;
}
"""
# Page layout and event wiring for the chat UI.
with gui.Blocks(theme=custom_theme, css=style_sheet) as demo_interface:
    # Static page header.
    gui.HTML("""
    <div style="text-align: center; margin-bottom: 1rem">
        <h1 style="font-size: 2.5rem; font-weight: 600; color: #1F2937">SmolLM2 Chat</h1>
        <p style="font-size: 1.1rem; color: #6B7280">
            Chat with SmolLM2-135M-Instruct: A small but capable AI assistant
        </p>
    </div>
    """)

    # Conversation view; "messages" type exchanges role/content message lists.
    chat_interface = gui.Chatbot(
        value=[],
        avatar_images=(None, "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot.png"),
        show_label=False,
        container=True,
        height=600,
        elem_classes="chat-container",
        type="messages",
    )

    # Input row: wide text box plus a narrow Send button.
    with gui.Row(elem_classes="input-container"):
        with gui.Column(scale=20):
            message_input = gui.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
                lines=2,
            )
        with gui.Column(scale=1, min_width=50):
            send_button = gui.Button("Send", variant="primary")

    with gui.Row():
        clear_button = gui.Button("Clear Chat", variant="secondary")

    # Pressing Enter and clicking Send run the same chain: stream the reply
    # into the chatbot, then blank the text box once streaming has finished.
    for register_trigger in (message_input.submit, send_button.click):
        register_trigger(
            stream_chat_response,
            [message_input, chat_interface],
            [chat_interface],
            queue=True,
        ).then(
            lambda: "",
            None,
            [message_input],
            queue=False,
        )

    # Clear Chat replaces the conversation with an empty list.
    clear_button.click(
        reset_chat,
        None,
        [chat_interface],
        queue=False,
    )

    # Second submit listener that blanks the text box on submit, independent
    # of the streaming chain registered above.
    message_input.submit(lambda: "", None, [message_input])
if __name__ == "__main__":
    # Serve on every network interface, port 5000, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 5000,
        "share": False,
    }
    demo_interface.launch(**launch_options)