Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| T5ForConditionalGeneration, | |
| T5Tokenizer, | |
| ) | |
| import time | |
| import hashlib | |
| from typing import List, Dict, Tuple, Optional | |
| import json | |
| import os | |
| # ============================================================ | |
| # Configuration | |
| # ============================================================ | |
| DEFAULT_MODEL = "Supra-50M-Instruct" | |
| TITLE_MODEL_ID = "SupraLabs/Supra-Title-Flan-85M" | |
| # Available models | |
| AVAILABLE_MODELS = { | |
| "Supra-50M-Instruct": { | |
| "id": "SupraLabs/Supra-50M-Instruct", | |
| "type": "instruct", | |
| "description": "50M parameter instruction-tuned model, suitable for general chat" | |
| }, | |
| "Supra-50M-Reasoning": { | |
| "id": "SupraLabs/Supra-50M-Reasoning", | |
| "type": "reasoning", | |
| "description": "50M reasoning model that outputs a thought process" | |
| }, | |
| "Supra-1.5-50M-Instruct-exp": { | |
| "id": "SupraLabs/Supra-1.5-50M-Instruct-exp", | |
| "type": "instruct", | |
| "description": "Experimental 50M instruct model with 5K context length" | |
| }, | |
| "Supra-50M-Base": { | |
| "id": "SupraLabs/Supra-50M-Base", | |
| "type": "base", | |
| "description": "50M base model, pure next‑token prediction" | |
| }, | |
| "StorySupra-10M": { | |
| "id": "SupraLabs/StorySupra-10M", | |
| "type": "base", | |
| "description": "10M story generation model" | |
| }, | |
| "Supra-Mini-v5-8M": { | |
| "id": "SupraLabs/Supra-Mini-v5-8M", | |
| "type": "base", | |
| "description": "8M ultra‑small model for fast experimentation" | |
| } | |
| } | |
| # ============================================================ | |
| # Model caching | |
| # ============================================================ | |
| _model_cache = {} | |
| _title_model = None | |
| _title_tokenizer = None | |
# ============================================================
# Title generator (Supra-Title-Flan-85M)
# ============================================================
def load_title_model():
    """Return the cached ``(model, tokenizer)`` pair used for titling chats.

    On the first call the T5 tokenizer and model are loaded from
    TITLE_MODEL_ID (model in float32, switched to eval mode) and stored in
    the module globals ``_title_model`` / ``_title_tokenizer``; later calls
    return the stored objects without reloading.
    """
    global _title_model, _title_tokenizer
    if _title_model is not None:
        return _title_model, _title_tokenizer

    print(f"[*] Loading title model: {TITLE_MODEL_ID}")
    _title_tokenizer = T5Tokenizer.from_pretrained(TITLE_MODEL_ID)
    _title_model = T5ForConditionalGeneration.from_pretrained(
        TITLE_MODEL_ID, torch_dtype=torch.float32
    )
    _title_model.eval()
    return _title_model, _title_tokenizer
def generate_chat_title(user_message: str, max_new_tokens: int = 32) -> str:
    """Produce a short conversation title from the first user message.

    The message is prefixed with ``"generate title: "`` (input capped at
    512 tokens) and decoded with 4-beam search. Titles longer than 50
    characters are cut to their first 47 characters plus ``"..."``.

    Args:
        user_message: The first message of the conversation.
        max_new_tokens: Generation budget for the title.

    Returns:
        The stripped title, or ``"New Conversation"`` if the title is empty
        or anything (including model loading) raises; failures are printed.
    """
    fallback = "New Conversation"
    try:
        model, tokenizer = load_title_model()
        encoded = tokenizer(
            f"generate title: {user_message.strip()}",
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                num_beams=4,
                early_stopping=True,
            )
        title = tokenizer.decode(generated[0], skip_special_tokens=True)
        title = f"{title[:47]}..." if len(title) > 50 else title
        return title.strip() or fallback
    except Exception as exc:
        print(f"[!] Title generation failed: {exc}")
        return fallback
# ============================================================
# Conversation model loader
# ============================================================
def load_model(model_key: str):
    """Load a conversation model, memoised in ``_model_cache``.

    On CUDA machines the model is loaded in bfloat16 with
    ``device_map="auto"``; otherwise it is loaded in float32 and moved to
    the CPU. The model is put in eval mode before caching.

    Args:
        model_key: A key of AVAILABLE_MODELS.

    Returns:
        Tuple ``(model, tokenizer, model_type, device)`` where ``device``
        is the string ``"cuda"`` or ``"cpu"``.

    Raises:
        ValueError: If ``model_key`` is not in AVAILABLE_MODELS.
    """
    cached = _model_cache.get(model_key)
    if cached is not None:
        return cached

    model_info = AVAILABLE_MODELS.get(model_key)
    if not model_info:
        raise ValueError(f"Unknown model: {model_key}")

    repo_id = model_info["id"]
    print(f"[*] Loading model: {repo_id}")

    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.bfloat16 if has_cuda else torch.float32,
        device_map="auto" if has_cuda else None,
    )
    if not has_cuda:
        # Without a device_map the weights must be placed explicitly.
        model = model.to(device)
    model.eval()

    entry = (model, tokenizer, model_info["type"], device)
    _model_cache[model_key] = entry
    return entry
# ============================================================
# Prompt construction
# ============================================================
def build_prompt(model_type: str, message: str, history: List[Tuple[str, str]]) -> str:
    """Render the chat as a plain ``User:``/``Assistant:`` transcript.

    Every past ``(user, assistant)`` pair becomes two lines, and the prompt
    ends with the new user message followed by an open ``Assistant:`` cue.
    For ``model_type == "reasoning"`` the thought-trigger token
    ``" <|begin_of_thought|>"`` is appended so the model starts with its
    thought process (expected to be followed by a solution section).

    Args:
        model_type: The AVAILABLE_MODELS ``type`` of the target model.
        message: The new user message.
        history: Previous turns, oldest first.

    Returns:
        The prompt string.
    """
    turns = [f"User: {user_msg}\nAssistant: {bot_msg}\n" for user_msg, bot_msg in history]
    turns.append(f"User: {message}\nAssistant:")
    prompt = "".join(turns)
    if model_type == "reasoning":
        prompt += " <|begin_of_thought|>"
    return prompt
# ============================================================
# Response generation
# ============================================================
def _trim_reply(text: str) -> str:
    """Cut generated text at the first invented ``User:`` turn and strip it.

    The prompt uses a ``User:``/``Assistant:`` transcript, so a model that
    keeps going past its reply tends to write the next user turn itself;
    nothing from that marker on belongs to the assistant's answer.
    """
    cut = text.find("\nUser:")
    if cut != -1:
        text = text[:cut]
    return text.strip()


def generate_response(
    model_key: str,
    message: str,
    history: List[Tuple[str, str]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate a reply from the selected conversation model.

    Args:
        model_key: A key of AVAILABLE_MODELS.
        message: The new user message.
        history: Previous ``(user, assistant)`` turns, oldest first.
        max_new_tokens: Generation budget for the reply.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters
            forwarded to ``model.generate`` (sampling is always enabled).

    Returns:
        The assistant's reply. For reasoning models this still contains the
        generated thought process. Returns a placeholder when the model
        produced nothing usable, and ``"Error: ..."`` when anything raised.
        Errors are reported in-band so the UI can display them.
    """
    try:
        model, tokenizer, model_type, device = load_model(model_key)
        prompt = build_prompt(model_type, message, history)

        # Prompt budget in tokens.
        # NOTE(review): 1024 for non-"1.5" models is assumed, not read from
        # the model config, and 2048 is below the 5K the "1.5" model advertises.
        max_prompt_tokens = 2048 if "1.5" in model_key else 1024

        # Tokenize in full, then keep only the most recent tokens. The
        # tokenizer's default truncation drops tokens from the END, which
        # would cut off the current user message and the trailing
        # "Assistant:" cue exactly when the history gets long.
        encoded = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v[:, -max_prompt_tokens:].to(device) for k, v in encoded.items()}
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Finding the prompt text
        # inside the decoded full sequence is unreliable, because
        # decode(encode(prompt)) need not reproduce the prompt exactly
        # (whitespace normalisation, stripped special tokens, truncation).
        generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        response = _trim_reply(generated)
        return response or "(Model did not produce a valid response)"
    except Exception as e:
        print(f"[!] Generation error: {e}")
        return f"Error: {str(e)}"
# ============================================================
# Gradio Interface
# ============================================================
def chat_interface(
    message: str,
    history: List[Dict],
    model_choice: str,
    temperature: float,
    max_tokens: int,
):
    """Generator callback that handles one chat turn for a Gradio chat.

    ``history`` is a messages-style list of ``{"role", "content"}`` dicts.
    Consecutive entries are paired as (user, assistant) turns for the
    prompt, and an unpaired trailing entry is ignored. For a non-blank
    message, the user message and the model's reply are appended to
    ``history`` in place.

    Yields:
        Exactly once: ``(history, "")``, where the empty string clears the
        input box. Blank messages yield the history unchanged.
    """
    if message and message.strip():
        pairs = [
            (history[idx]["content"], history[idx + 1]["content"])
            for idx in range(0, len(history) - 1, 2)
        ]
        reply = generate_response(
            model_choice,
            message,
            pairs,
            max_new_tokens=max_tokens,
            temperature=temperature,
        )
        history.extend((
            {"role": "user", "content": message},
            {"role": "assistant", "content": reply},
        ))
    yield history, ""
def get_title_from_first_message(message: str) -> str:
    """Return a generated title for ``message``, or ``"New Conversation"`` if it is blank/None."""
    has_text = bool(message) and bool(message.strip())
    return generate_chat_title(message) if has_text else "New Conversation"
# ============================================================
# Create Gradio app
# ============================================================
def create_app():
    """Build the SupraChat Gradio Blocks UI and wire its events.

    Layout:
    - a row of generation controls (model, temperature, max new tokens);
    - the chat transcript;
    - a message box with a Send button;
    - a Clear button beside a read-only conversation title;
    - a model overview table.

    Per-session chat history is a ``gr.State`` list of
    ``{"role": ..., "content": ...}`` dicts, which is also rendered by the
    Chatbot.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="SupraChat – SupraLabs Chat Interface") as demo:
        gr.Markdown("""
# 🤖 SupraChat
Chat interface powered by SupraLabs' ultra‑small language models.
Conversation history is stored in RAM and cleared when you leave the page.
""")
        with gr.Row():
            with gr.Column(scale=4):
                model_choice = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value=DEFAULT_MODEL,
                    label="Select Model",
                    info="Different models have different strengths",
                )
            with gr.Column(scale=2):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative",
                )
            with gr.Column(scale=2):
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max New Tokens",
                    info="Maximum length of the reply",
                )
        chatbot = gr.Chatbot(
            label="Conversation",
            height=500,
        )
        with gr.Row():
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your message here...",
                scale=9,
                container=False,
            )
            send_btn = gr.Button("Send", scale=1, variant="primary")
        with gr.Row():
            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")
            title_display = gr.Textbox(
                label="Conversation Title",
                placeholder="Auto‑generated from the first message",
                interactive=False,
                scale=1,
            )
        state = gr.State([])

        # ============================================================
        # Event handlers
        # ============================================================
        def respond(
            message: str,
            history: List[Dict],
            model: str,
            temp: float,
            max_tok: int,
            current_title: str = "",
        ):
            """Handle one user turn.

            Returns ``(chatbot history, cleared textbox, state, title)``.
            The title is generated once, from the first message. Later and
            blank turns return ``current_title`` so the displayed title is
            kept (previously it was overwritten with "").
            """
            if not message or not message.strip():
                return history, "", history, current_title

            is_first_turn = len(history) == 0

            # Pair consecutive (user, assistant) entries; an unpaired
            # trailing entry is ignored.
            formatted_history = [
                (history[i]["content"], history[i + 1]["content"])
                for i in range(0, len(history) - 1, 2)
            ]
            response = generate_response(
                model,
                message,
                formatted_history,
                max_new_tokens=max_tok,
                temperature=temp,
            )
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": response})

            # Run the title model only once per conversation.
            title = get_title_from_first_message(message) if is_first_turn else current_title
            return history, "", history, title

        def clear_chat():
            """Reset transcript, input box, title and session state in one update."""
            return [], "", "New Conversation", []

        # The current title is an input so respond can hand it back unchanged.
        chat_inputs = [msg, state, model_choice, temperature, max_tokens, title_display]
        chat_outputs = [chatbot, msg, state, title_display]

        # Send button and Enter key trigger the same handler.
        send_btn.click(fn=respond, inputs=chat_inputs, outputs=chat_outputs)
        msg.submit(fn=respond, inputs=chat_inputs, outputs=chat_outputs)

        clear_btn.click(
            fn=clear_chat,
            inputs=[],
            outputs=[chatbot, msg, title_display, state],
        )

        gr.Markdown("""
---
### 📋 Model Overview
| Model | Type | Description |
|-------|------|-------------|
| **Supra-50M-Instruct** | Instruct | General‑purpose chat, 50M parameters |
| **Supra-50M-Reasoning** | Reasoning | Includes a thought process for complex tasks |
| **Supra-1.5-50M-Instruct-exp** | Instruct | Experimental, 5K context window |
| **Supra-50M-Base** | Base | Raw language modelling, no instruction tuning |
| **StorySupra-10M** | Base | Specialised for story generation |
| **Supra-Mini-v5-8M** | Base | Extremely small, fast responses |
> 💡 **Note**: Conversation history is kept in memory only. It will be cleared when you reload or close the page.
""")
    return demo
# ============================================================
# Launch
# ============================================================
if __name__ == "__main__":
    demo = create_app()
    print("App created, setting up queue...")
    # Up to 5 events of the same listener may run concurrently.
    demo.queue(default_concurrency_limit=5)
    print("Queue set, launching...")

    # NOTE(review): theme/css are passed to launch() rather than gr.Blocks();
    # confirm the pinned Gradio version accepts them here. The selectors
    # below only apply to components given matching elem_classes.
    soft_theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
        neutral_hue="gray",
    )
    custom_css = """
.chatbot-container {
max-width: 800px;
margin: 0 auto;
}
.model-selector {
margin-bottom: 10px;
}
.title-input {
font-size: 1.2em;
font-weight: bold;
}
"""
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        debug=True,
        theme=soft_theme,
        css=custom_css,
    )