import os
import time
from typing import Any, Dict, List, Tuple

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import spaces

# =========================
# Configuration
# =========================
MODEL_ID = "facebook/MobileLLM-Pro"
MODEL_SUBFOLDER = "instruct"  # "base" | "instruct"
MAX_HISTORY_LENGTH = 10  # max user/assistant turn *pairs* kept in the prompt
MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, friendly, and intelligent assistant. "
    "Provide clear, accurate, and thoughtful responses."
)

# =========================
# HF Login (optional)
# =========================
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged in to Hugging Face")
    except Exception as e:
        # Best-effort: a failed login should not prevent the app from starting
        print(f"Warning: Could not login to Hugging Face: {e}")


# =========================
# Utilities
# =========================
def tuples_from_messages(messages: List[Dict[str, Any]]) -> List[List[str]]:
    """
    Convert a Chatbot(type='messages') style history into tuples format
    [[user, assistant], ...]. If already tuples-like, return as-is.
    """
    if not messages:
        return []
    # If already tuples-like (list with elements of length 2), return a copy
    if isinstance(messages[0], (list, tuple)) and len(messages[0]) == 2:
        return [list(x) for x in messages]
    # Otherwise, convert from [{"role": "...", "content": "..."}, ...]
    pairs: List[List[str]] = []
    last_user: str | None = None
    for m in messages:
        role = m.get("role")
        content = m.get("content", "")
        if role == "user":
            last_user = content
        elif role == "assistant":
            if last_user is None:
                # If assistant appears first (odd state), pair with empty user
                pairs.append(["", content])
            else:
                pairs.append([last_user, content])
            last_user = None
    # If there's a dangling user without assistant, pair with empty string
    if last_user is not None:
        pairs.append([last_user, ""])
    return pairs


def messages_from_tuples(history_tuples: List[List[str]]) -> List[Dict[str, str]]:
    """
    Convert tuples [[user, assistant], ...] into a list of role dictionaries:
    [{"role": "user", ...}, {"role": "assistant", ...}, ...]

    Empty assistant slots (dangling user turns) are skipped.
    """
    messages: List[Dict[str, str]] = []
    for u, a in history_tuples:
        messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    return messages


# =========================
# Chat Model Wrapper
# =========================
class MobileLLMChat:
    """Wrapper around MobileLLM-Pro: loads to CPU, moves to GPU per request."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.load_model(version=MODEL_SUBFOLDER)

    def load_model(self, version="instruct"):
        """Load the MobileLLM-Pro model and tokenizer (initially to CPU).

        Returns True on success, False on failure (error is printed).
        """
        try:
            print(f"Loading {MODEL_ID} ({version})...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_ID, trust_remote_code=True, subfolder=version
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                subfolder=version,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            )
            # Safety: ensure pad token exists (some LLMs don't set it)
            if self.tokenizer.pad_token_id is None:
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.model.eval()
            self.model_loaded = True
            print("Model loaded successfully to system memory (CPU).")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            return False

    def format_chat_history(
        self, history: List[Dict[str, str]], system_prompt: str
    ) -> List[Dict[str, str]]:
        """Format chat history for the tokenizer's chat template.

        Prepends the system prompt and keeps only the last
        MAX_HISTORY_LENGTH user/assistant turn pairs.
        """
        messages = [{"role": "system", "content": system_prompt}]
        # Keep only user/assistant entries, then truncate to the last N turns
        trimmed = [msg for msg in history if msg["role"] in ("user", "assistant")]
        if MAX_HISTORY_LENGTH > 0:
            trimmed = trimmed[-(MAX_HISTORY_LENGTH * 2):]
        messages.extend(trimmed)
        return messages

    @spaces.GPU(duration=120)
    def generate_response(
        self,
        user_input: str,
        history: List[Dict[str, str]],
        system_prompt: str,
        temperature: float = 0.7,
        max_new_tokens: int = MAX_NEW_TOKENS,
    ) -> str:
        """Generate a full response (GPU during inference).

        Mutates `history` in place (appends the user turn, and the
        assistant turn on success). Returns the decoded response text,
        or an error string on failure.
        """
        if not self.model_loaded:
            return "Model not loaded. Please try reloading the space."
        # Choose device (Spaces GPU if available)
        use_cuda = torch.cuda.is_available()
        try:
            self.device = torch.device("cuda" if use_cuda else "cpu")
            self.model.to(self.device)

            # Append the new user message
            history.append({"role": "user", "content": user_input})
            messages = self.format_chat_history(history, system_prompt)

            # Build inputs with chat template
            input_ids = self.tokenizer.apply_chat_template(
                messages, return_tensors="pt", add_generation_prompt=True
            ).to(self.device)
            # No padding used here -> full ones mask
            attention_mask = torch.ones_like(input_ids)

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Slice only the newly generated tokens
            gen_ids = outputs[0][input_ids.shape[1]:]
            response = self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

            # Update history (internal state for the caller if desired)
            history.append({"role": "assistant", "content": response})
            return response
        except Exception as e:
            return f"Error generating response: {str(e)}"
        finally:
            # BUG FIX: always release GPU VRAM, even when generation raised.
            # The original only did this on the success path, leaking VRAM
            # on Spaces whenever an exception occurred mid-generation.
            if use_cuda:
                self.model.to("cpu")
                torch.cuda.empty_cache()


# =========================
# Initialize Chat Model
# =========================
print("Initializing MobileLLM-Pro model...")
chat_model = MobileLLMChat()


# =========================
# Gradio Helpers
# =========================
def clear_chat():
    """Clear the chat history and input box."""
    return [], ""


def chat_fn(message, history, system_prompt, temperature):
    """Non-streaming chat handler (returns tuples history)."""
    # DEFENSIVE: ensure tuples format regardless of what the Chatbot sends
    history = tuples_from_messages(history)
    if not chat_model.model_loaded:
        return history + [[message, "Please wait for the model to load or reload the space."]]
    # Convert tuples -> role dicts for the model
    formatted_history = messages_from_tuples(history)
    # Generate full response once
    response = chat_model.generate_response(message, formatted_history, system_prompt, temperature)
    # Return updated tuples history
    return history + [[message, response]]


def chat_stream_fn(message, history, system_prompt, temperature):
    """Streaming chat handler (tuples): generate once, then chunk out."""
    # DEFENSIVE: ensure tuples format
    history = tuples_from_messages(history)
    if not chat_model.model_loaded:
        yield history + [[message, "Please wait for the model to load or reload the space."]]
        return
    # Convert tuples -> role dicts for the model
    formatted_history = messages_from_tuples(history)
    # Generate full response (GPU)
    full_response = chat_model.generate_response(
        message, formatted_history, system_prompt, temperature
    )
    if not isinstance(full_response, str):
        full_response = str(full_response)
    # Progressively reveal the assistant side in ~40 chunks
    step = max(8, len(full_response) // 40)
    for i in range(0, len(full_response), step):
        yield history + [[message, full_response[: i + step]]]
    # Final yield ensures the complete text is shown
    yield history + [[message, full_response]]


def handle_chat(message, history, system_prompt, temperature, streaming):
    """Dispatch to the streaming or non-streaming handler.

    BUG FIX: this must itself be a generator function. Gradio decides
    whether an event handler streams via inspect.isgeneratorfunction;
    the original was a plain function that *returned* a generator object
    in streaming mode, so the Chatbot would be handed the generator's
    repr instead of the streamed updates.
    """
    if streaming:
        yield from chat_stream_fn(message, history, system_prompt, temperature)
    else:
        yield chat_fn(message, history, system_prompt, temperature)


# =========================
# Gradio UI
# =========================
with gr.Blocks(
    title="MobileLLM-Pro Chat",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container { max-width: 900px !important; margin: auto !important; }
    .message { padding: 12px !important; border-radius: 8px !important; margin-bottom: 8px !important; }
    .user-message { background-color: #e3f2fd !important; margin-left: 20% !important; }
    .assistant-message { background-color: #f5f5f5 !important; margin-right: 20% !important; }
    """,
) as demo:
    # Header
    # NOTE(review): the original HTML markup was lost in extraction; this is
    # a minimal reconstruction that preserves the original visible text.
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>🤖 MobileLLM-Pro Chat</h1>
            <p>Built with anycoder</p>
            <p>Chat with Facebook's MobileLLM-Pro model optimized for on-device inference</p>
        </div>
        """
    )

    # Model status
    with gr.Row():
        model_status = gr.Textbox(
            label="Model Status",
            value="Model loaded and ready!" if chat_model.model_loaded else "Model loading...",
            interactive=False,
            container=True,
        )

    # Config
    with gr.Accordion("⚙️ Configuration", open=False):
        with gr.Row():
            system_prompt = gr.Textbox(
                value=DEFAULT_SYSTEM_PROMPT,
                label="System Prompt",
                lines=3,
                info="Customize the AI's behavior and personality",
            )
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness (higher = more creative)",
            )
            streaming = gr.Checkbox(
                value=True,
                label="Enable Streaming",
                info="Show responses as they're being generated",
            )

    # Chatbot in TUPLES mode (explicit)
    chatbot = gr.Chatbot(
        type="tuples",
        label="Chat History",
        height=500,
        show_copy_button=True,
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            scale=4,
            container=False,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)
        clear_btn = gr.Button("Clear", scale=0)

    # Wire events (also clear the input box after send)
    msg.submit(
        handle_chat,
        inputs=[msg, chatbot, system_prompt, temperature, streaming],
        outputs=[chatbot],
    ).then(lambda: "", None, msg)

    submit_btn.click(
        handle_chat,
        inputs=[msg, chatbot, system_prompt, temperature, streaming],
        outputs=[chatbot],
    ).then(lambda: "", None, msg)

    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg],
    )

    # Examples
    gr.Examples(
        examples=[
            ["What are the benefits of on-device AI models?"],
            ["Explain quantum computing in simple terms."],
            ["Write a short poem about technology."],
            ["What's the difference between machine learning and deep learning?"],
            ["How can I improve my productivity?"],
        ],
        inputs=[msg],
        label="Example Prompts",
    )

    # Footer
    # NOTE(review): reconstructed markup — visible text preserved from original.
    gr.HTML(
        """
        <div style="text-align: center;">
            <p>⚠️ Note: Model is pre-loaded for faster inference. GPU is allocated only during generation.</p>
            <p>Model: facebook/MobileLLM-Pro</p>
        </div>
        """
    )

# Optional: queue to improve streaming UX
demo.queue()

# Launch (NO share=True on Spaces)
if __name__ == "__main__":
    demo.launch(
        show_error=True,
        debug=True,
    )