import hashlib
import json
import os
import threading
import time
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# ============================================================
# Configuration
# ============================================================
DEFAULT_MODEL = "Supra-50M-Instruct"
TITLE_MODEL_ID = "SupraLabs/Supra-Title-Flan-85M"

# Available models: display key -> hub id, prompt style ("type") and blurb.
AVAILABLE_MODELS = {
    "Supra-50M-Instruct": {
        "id": "SupraLabs/Supra-50M-Instruct",
        "type": "instruct",
        "description": "50M parameter instruction-tuned model, suitable for general chat",
    },
    "Supra-50M-Reasoning": {
        "id": "SupraLabs/Supra-50M-Reasoning",
        "type": "reasoning",
        "description": "50M reasoning model that outputs a thought process",
    },
    "Supra-1.5-50M-Instruct-exp": {
        "id": "SupraLabs/Supra-1.5-50M-Instruct-exp",
        "type": "instruct",
        "description": "Experimental 50M instruct model with 5K context length",
    },
    "Supra-50M-Base": {
        "id": "SupraLabs/Supra-50M-Base",
        "type": "base",
        "description": "50M base model, pure next‑token prediction",
    },
    "StorySupra-10M": {
        "id": "SupraLabs/StorySupra-10M",
        "type": "base",
        "description": "10M story generation model",
    },
    "Supra-Mini-v5-8M": {
        "id": "SupraLabs/Supra-Mini-v5-8M",
        "type": "base",
        "description": "8M ultra‑small model for fast experimentation",
    },
}

# ============================================================
# Model caching
# ============================================================
# model key -> (model, tokenizer, model_type, device)
_model_cache = {}
_title_model = None
_title_tokenizer = None

# The app is launched with a request queue that allows several callbacks to
# run at once; this lock keeps two requests from loading the same weights
# simultaneously.
_load_lock = threading.Lock()

# Titles longer than this many characters are shortened with an ellipsis.
_MAX_TITLE_CHARS = 50


# ============================================================
# Title generator (Supra-Title-Flan-85M)
# ============================================================
def load_title_model():
    """Load the title generation model on first use and cache it.

    Returns:
        A ``(model, tokenizer)`` pair for ``TITLE_MODEL_ID``. The model is
        loaded in float32 and put into eval mode.
    """
    global _title_model, _title_tokenizer
    if _title_model is None:
        with _load_lock:
            # Re-check under the lock: another request may have finished
            # loading while this one was waiting.
            if _title_model is None:
                print(f"[*] Loading title model: {TITLE_MODEL_ID}")
                tokenizer = T5Tokenizer.from_pretrained(TITLE_MODEL_ID)
                model = T5ForConditionalGeneration.from_pretrained(
                    TITLE_MODEL_ID,
                    torch_dtype=torch.float32,
                )
                model.eval()
                # Publish the tokenizer before the model, because callers
                # only test ``_title_model`` to decide whether both are ready.
                _title_tokenizer = tokenizer
                _title_model = model
    return _title_model, _title_tokenizer


def generate_chat_title(user_message: str, max_new_tokens: int = 32) -> str:
    """Generate a conversation title based on the first user message.

    Args:
        user_message: The text the title should summarise.
        max_new_tokens: Upper bound on the number of generated title tokens.

    Returns:
        The generated title, shortened to at most 50 characters, or
        ``"New Conversation"`` when the message is blank, the model produces
        nothing, or generation fails for any reason (best effort).
    """
    # Nothing to summarise: skip the model call entirely.
    if not user_message or not user_message.strip():
        return "New Conversation"

    try:
        model, tokenizer = load_title_model()
        prompt = f"generate title: {user_message.strip()}"
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=4,
                early_stopping=True,
            )
        # Strip before measuring so surrounding whitespace does not count
        # towards the length limit.
        title = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        if len(title) > _MAX_TITLE_CHARS:
            title = title[: _MAX_TITLE_CHARS - 3].rstrip() + "..."
        return title or "New Conversation"
    except Exception as e:
        # Title generation is cosmetic; never let it break the chat.
        print(f"[!] Title generation failed: {e}")
        return "New Conversation"


# ============================================================
# Conversation model loader
# ============================================================
def load_model(model_key: str):
    """Load the specified conversation model, reusing a cached copy if any.

    Args:
        model_key: A key of ``AVAILABLE_MODELS``.

    Returns:
        A ``(model, tokenizer, model_type, device)`` tuple, where
        ``model_type`` is the entry's ``"type"`` and ``device`` is ``"cuda"``
        or ``"cpu"``.

    Raises:
        ValueError: If ``model_key`` is not in ``AVAILABLE_MODELS``.
    """
    cached = _model_cache.get(model_key)
    if cached is not None:
        return cached

    model_info = AVAILABLE_MODELS.get(model_key)
    if not model_info:
        raise ValueError(f"Unknown model: {model_key}")

    with _load_lock:
        # Another request may have loaded this model while we waited.
        cached = _model_cache.get(model_key)
        if cached is not None:
            return cached

        model_id = model_info["id"]
        print(f"[*] Loading model: {model_id}")

        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        torch_dtype = torch.bfloat16 if use_cuda else torch.float32

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
        )
        # Place the whole model on one device: callers move their inputs to
        # ``device``, so the weights must live there too.
        model = model.to(device)
        model.eval()

        entry = (model, tokenizer, model_info["type"], device)
        _model_cache[model_key] = entry
        return entry


# ============================================================
# Prompt construction
# ============================================================


def build_prompt(model_type: str, message: str, history: List[Tuple[str, str]]) -> str:
    """Construct the prompt according to the model type.

    Args:
        model_type: The model's prompt style; ``"reasoning"`` gets a thought
            trigger appended, every other value uses the plain transcript.
        message: The new user message.
        history: Earlier ``(user, assistant)`` exchanges, oldest first.

    Returns:
        A ``User:`` / ``Assistant:`` transcript ending with an open
        ``Assistant:`` turn for the model to complete.
    """
    turns = [
        f"User: {user_msg}\nAssistant: {bot_msg}\n"
        for user_msg, bot_msg in history
    ]
    turns.append(f"User: {message}\nAssistant:")
    conversation = "".join(turns)

    if model_type == "reasoning":
        # For reasoning models, we add the thought trigger token.
        # The model will then generate <|begin_of_thought|> ... <|end_of_thought|>
        # followed by <|begin_of_solution|> ... <|end_of_solution|>
        return conversation + " <|begin_of_thought|>"
    return conversation


# ============================================================
# Response generation
# ============================================================
def _trim_reply(text: str) -> str:
    """Cut a reply at the first follow-up ``User:`` turn the model invented."""
    reply, _, _ = text.partition("\nUser:")
    return reply.strip()


def _pair_history(history: Optional[List[Dict]]) -> List[Tuple[str, str]]:
    """Convert a flat role/content message list into (user, assistant) pairs."""
    messages = history or []
    # Messages alternate user/assistant; a trailing unpaired message is dropped.
    return [
        (user["content"], bot["content"])
        for user, bot in zip(messages[0::2], messages[1::2])
    ]


def generate_response(
    model_key: str,
    message: str,
    history: List[Tuple[str, str]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate a response from the selected model.

    Args:
        model_key: Key identifying the model to load via ``load_model``.
        message: The new user message.
        history: Earlier ``(user, assistant)`` exchanges, oldest first.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The model's reply text. On failure the exception is not raised;
        a string starting with ``"Error: "`` is returned instead so the UI
        can display it.
    """
    try:
        model, tokenizer, model_type, device = load_model(model_key)
        prompt = build_prompt(model_type, message, history)

        # NOTE(review): the prompt budget does not reserve room for
        # max_new_tokens; confirm prompt + reply fits each model's context.
        max_prompt_tokens = 2048 if "1.5" in model_key else 1024
        encoded = tokenizer(prompt, return_tensors="pt")
        # Keep the LAST tokens when the prompt is too long: the newest
        # message and the open "Assistant:" cue sit at the end, so cutting
        # from the right would discard exactly what the model must answer.
        inputs = {
            name: tensor[:, -max_prompt_tokens:].to(device)
            for name, tensor in encoded.items()
        }
        prompt_len = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # The output sequence starts with the prompt tokens; decode only what
        # was generated after them. This avoids string-matching the prompt in
        # the decoded text, which fails once the prompt was truncated or
        # decoding does not reproduce it character for character.
        new_tokens = outputs[0][prompt_len:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Small models tend to keep writing the next "User:" turn themselves;
        # drop it so invented turns are not shown or fed back into history.
        # For reasoning models the thought + answer text is kept as generated.
        response = _trim_reply(response)

        return response or "(Model did not produce a valid response)"
    except Exception as e:
        print(f"[!] Generation error: {e}")
        return f"Error: {str(e)}"


# ============================================================
# Gradio Interface
# ============================================================
def chat_interface(
    message: str,
    history: List[Dict],
    model_choice: str,
    temperature: float,
    max_tokens: int,
):
    """Gradio chat interface callback.

    Args:
        message: The new user message.
        history: Flat list of ``{"role", "content"}`` messages; the new
            exchange is appended to this list.
        model_choice: Key identifying the model to use.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.

    Yields:
        A single ``(history, "")`` tuple; the empty string clears the input
        box. Blank messages leave the history unchanged.
    """
    if history is None:
        history = []

    if not message or not message.strip():
        yield history, ""
        return

    response = generate_response(
        model_choice,
        message,
        _pair_history(history),
        max_new_tokens=max_tokens,
        temperature=temperature,
    )

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    yield history, ""


def get_title_from_first_message(message: str) -> str:
    """Generate a title from the first user message.

    Returns ``"New Conversation"`` for an empty or whitespace-only message.
    """
    if message and message.strip():
        return generate_chat_title(message)
    return "New Conversation"


# ============================================================
# Create Gradio app
# ============================================================
def create_app():
    """Create and return the Gradio Blocks app."""
    with gr.Blocks(title="SupraChat – SupraLabs Chat Interface") as demo:
        gr.Markdown("""
        # 🤖 SupraChat
        Chat interface powered by SupraLabs' ultra‑small language models.
        Conversation history is stored in RAM and cleared when you leave the page.
        """)

        with gr.Row():
            with gr.Column(scale=4):
                model_choice = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value=DEFAULT_MODEL,
                    label="Select Model",
                    info="Different models have different strengths",
                    # Hooks the component up to the ".model-selector" CSS rule.
                    elem_classes="model-selector",
                )
            with gr.Column(scale=2):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative",
                )
            with gr.Column(scale=2):
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max New Tokens",
                    info="Maximum length of the reply",
                )

        chatbot = gr.Chatbot(
            label="Conversation",
            height=500,
            # Hooks the component up to the ".chatbot-container" CSS rule.
            elem_classes="chatbot-container",
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your message here...",
                scale=9,
                container=False,
            )
            send_btn = gr.Button("Send", scale=1, variant="primary")

        with gr.Row():
            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")
            title_display = gr.Textbox(
                label="Conversation Title",
                placeholder="Auto‑generated from the first message",
                interactive=False,
                scale=1,
                # Hooks the component up to the ".title-input" CSS rule.
                elem_classes="title-input",
            )

        # Per-session message list: [{"role": ..., "content": ...}, ...]
        state = gr.State([])

        # ============================================================
        # Event handlers
        # ============================================================
        def respond(
            message: str,
            history: List[Dict],
            model: str,
            temp: float,
            max_tok: int,
            current_title: str,
        ):
            """Answer one message; returns (chatbot, msg, state, title)."""
            history = history or []

            # Blank input: change nothing, and keep the existing title.
            if not message or not message.strip():
                return history, "", history, current_title

            is_first_message = not history

            response = generate_response(
                model,
                message,
                _pair_history(history),
                max_new_tokens=max_tok,
                temperature=temp,
            )

            # Build a new list rather than mutating the session's state object.
            history = history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response},
            ]

            # The title is generated exactly once, from the first message;
            # later turns hand the current title back so it stays visible.
            if is_first_message:
                title = get_title_from_first_message(message)
            else:
                title = current_title

            return history, "", history, title

        def clear_chat():
            """Reset chat display, input box, title and stored history."""
            return [], "", "New Conversation", []

        respond_inputs = [
            msg, state, model_choice, temperature, max_tokens, title_display,
        ]
        respond_outputs = [chatbot, msg, state, title_display]

        # Send button
        send_btn.click(
            fn=respond,
            inputs=respond_inputs,
            outputs=respond_outputs,
        )

        # Enter key
        msg.submit(
            fn=respond,
            inputs=respond_inputs,
            outputs=respond_outputs,
        )

        # Clear
        clear_btn.click(
            fn=clear_chat,
            inputs=[],
            outputs=[chatbot, msg, title_display, state],
        )

        # NOTE(review): the table advertises a 5K context window for
        # Supra-1.5-50M-Instruct-exp while generate_response caps its prompt
        # at 2048 tokens; confirm which is intended.
        gr.Markdown("""
        ---
        ### 📋 Model Overview

        | Model | Type | Description |
        |-------|------|-------------|
        | **Supra-50M-Instruct** | Instruct | General‑purpose chat, 50M parameters |
        | **Supra-50M-Reasoning** | Reasoning | Includes a thought process for complex tasks |
        | **Supra-1.5-50M-Instruct-exp** | Instruct | Experimental, 5K context window |
        | **Supra-50M-Base** | Base | Raw language modelling, no instruction tuning |
        | **StorySupra-10M** | Base | Specialised for story generation |
        | **Supra-Mini-v5-8M** | Base | Extremely small, fast responses |

        > 💡 **Note**: Conversation history is kept in memory only. It will be cleared when you reload or close the page.
        """)

    return demo


# ============================================================
# Launch
# ============================================================
if __name__ == "__main__":
    demo = create_app()
    print("App created, setting up queue...")
    demo.queue(default_concurrency_limit=5)
    print("Queue set, launching...")
    # NOTE(review): theme/css are passed to launch() here; some Gradio
    # releases expect them on gr.Blocks(...) instead. Confirm against the
    # pinned Gradio version.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="gray",
            neutral_hue="gray",
        ),
        css="""
        .chatbot-container { max-width: 800px; margin: 0 auto; }
        .model-selector { margin-bottom: 10px; }
        .title-input { font-size: 1.2em; font-weight: bold; }
        """
    )