LH-Tech-AI's picture
Update app.py
fcb519e verified
Raw
History Blame Contribute Delete
15 kB
import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
T5ForConditionalGeneration,
T5Tokenizer,
)
import time
import hashlib
from typing import List, Dict, Tuple, Optional
import json
import os
# ============================================================
# Configuration
# ============================================================
# Dropdown entry selected when the UI first loads (a key of AVAILABLE_MODELS).
DEFAULT_MODEL = "Supra-50M-Instruct"
# Hub repo of the seq2seq model used to name conversations.
TITLE_MODEL_ID = "SupraLabs/Supra-Title-Flan-85M"
# Conversation models offered in the UI, keyed by display name.
#   id          -> Hugging Face Hub repo passed to from_pretrained
#   type        -> prompt style; "reasoning" gets a thought-trigger suffix
#   description -> human-readable summary
AVAILABLE_MODELS = {
    "Supra-50M-Instruct": dict(
        id="SupraLabs/Supra-50M-Instruct",
        type="instruct",
        description="50M parameter instruction-tuned model, suitable for general chat",
    ),
    "Supra-50M-Reasoning": dict(
        id="SupraLabs/Supra-50M-Reasoning",
        type="reasoning",
        description="50M reasoning model that outputs a thought process",
    ),
    "Supra-1.5-50M-Instruct-exp": dict(
        id="SupraLabs/Supra-1.5-50M-Instruct-exp",
        type="instruct",
        description="Experimental 50M instruct model with 5K context length",
    ),
    "Supra-50M-Base": dict(
        id="SupraLabs/Supra-50M-Base",
        type="base",
        description="50M base model, pure next‑token prediction",
    ),
    "StorySupra-10M": dict(
        id="SupraLabs/StorySupra-10M",
        type="base",
        description="10M story generation model",
    ),
    "Supra-Mini-v5-8M": dict(
        id="SupraLabs/Supra-Mini-v5-8M",
        type="base",
        description="8M ultra‑small model for fast experimentation",
    ),
}
# ============================================================
# Model caching
# ============================================================
_model_cache = {}
_title_model = None
_title_tokenizer = None
# ============================================================
# Title generator (Supra-Title-Flan-85M)
# ============================================================
def load_title_model():
    """Return the ``(model, tokenizer)`` pair used for conversation titles.

    On the first call both are loaded from ``TITLE_MODEL_ID`` (the model in
    float32 and switched to eval mode) and stored in the module-level
    ``_title_model`` / ``_title_tokenizer`` globals; later calls reuse them.
    """
    global _title_model, _title_tokenizer
    # Fast path: already loaded by an earlier call.
    if _title_model is not None:
        return _title_model, _title_tokenizer
    print(f"[*] Loading title model: {TITLE_MODEL_ID}")
    _title_tokenizer = T5Tokenizer.from_pretrained(TITLE_MODEL_ID)
    # nn.Module.eval() returns the module itself, so it can be chained.
    _title_model = T5ForConditionalGeneration.from_pretrained(
        TITLE_MODEL_ID, torch_dtype=torch.float32
    ).eval()
    return _title_model, _title_tokenizer
def generate_chat_title(
    user_message: str, max_new_tokens: int = 32, max_chars: int = 50
) -> str:
    """Generate a short conversation title from the first user message.

    This is best-effort. A blank or non-string message, an empty model output,
    or any failure while loading or running the title model all yield
    "New Conversation". Failures are logged rather than raised.

    Args:
        user_message: The opening message of the conversation.
        max_new_tokens: Generation budget for the title model.
        max_chars: Maximum title length; longer titles are cut and end
            with "...".

    Returns:
        The generated title, or "New Conversation" as a fallback.
    """
    fallback = "New Conversation"
    # Nothing to summarise: skip the (beam-search) model call entirely.
    if not isinstance(user_message, str) or not user_message.strip():
        return fallback
    try:
        model, tokenizer = load_title_model()
        inputs = tokenizer(
            f"generate title: {user_message.strip()}",
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=4,
                early_stopping=True,
            )
        title = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    except Exception as e:
        # Titles are cosmetic; a failure here must never break the chat.
        print(f"[!] Title generation failed: {e}")
        return fallback
    # Measure the *stripped* title so surrounding whitespace can neither
    # trigger truncation nor end up in front of the ellipsis.
    if len(title) > max_chars:
        title = title[: max(max_chars - 3, 0)].rstrip() + "..."
    return title or fallback
# ============================================================
# Conversation model loader
# ============================================================
def load_model(model_key: str):
    """Return ``(model, tokenizer, model_type, device)`` for ``model_key``.

    The first request for a key loads the tokenizer and causal-LM weights
    from the Hub repo named in AVAILABLE_MODELS. On CUDA the weights are
    loaded in bfloat16 with ``device_map="auto"``. Without CUDA they are
    loaded in float32 and moved to the CPU. The model is put in eval mode
    and the resulting tuple is cached in ``_model_cache``, so later calls
    return the same tuple object.

    Raises:
        ValueError: if ``model_key`` is not a key of AVAILABLE_MODELS.
    """
    cached = _model_cache.get(model_key)
    if cached is not None:
        return cached

    model_info = AVAILABLE_MODELS.get(model_key)
    if not model_info:
        raise ValueError(f"Unknown model: {model_key}")
    model_id, model_type = model_info["id"], model_info["type"]
    print(f"[*] Loading model: {model_id}")

    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if on_gpu else torch.float32,
        # On GPU, placement is delegated to transformers; on CPU it is done below.
        device_map="auto" if on_gpu else None,
    )
    if not on_gpu:
        model = model.to(device)
    model.eval()

    entry = (model, tokenizer, model_type, device)
    _model_cache[model_key] = entry
    return entry
# ============================================================
# Prompt construction
# ============================================================
def build_prompt(model_type: str, message: str, history: List[Tuple[str, str]]) -> str:
    """Render the conversation as a plain-text ``User:`` / ``Assistant:`` transcript.

    Every earlier ``(user, assistant)`` turn becomes a ``User:`` line followed
    by an ``Assistant:`` line. The new message is added as a final ``User:``
    line followed by a bare ``Assistant:`` cue for the model to complete. For
    ``model_type == "reasoning"`` the prompt also ends with
    `` <|begin_of_thought|>`` so generation starts inside the thought section.
    """
    turns = [f"User: {user}\nAssistant: {bot}\n" for user, bot in history]
    turns.append(f"User: {message}\nAssistant:")
    transcript = "".join(turns)
    suffix = " <|begin_of_thought|>" if model_type == "reasoning" else ""
    return transcript + suffix
# ============================================================
# Response generation
# ============================================================
def generate_response(
    model_key: str,
    message: str,
    history: List[Tuple[str, str]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> str:
    """Generate the assistant's reply to ``message`` with the selected model.

    Args:
        model_key: Key of AVAILABLE_MODELS naming the model to use.
        message: The new user message.
        history: Earlier ``(user, assistant)`` turns, oldest first.
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded to ``model.generate``.

    Returns:
        The reply text, or a placeholder if the model produced nothing.
        If loading or generation raises, the error is logged and returned
        in-band as ``"Error: ..."`` instead of propagating.
    """
    try:
        model, tokenizer, model_type, device = load_model(model_key)
        prompt = build_prompt(model_type, message, history)

        # Prompt budget in tokens (the "1.5" model gets a larger window).
        # NOTE(review): this does not reserve room for max_new_tokens; confirm
        # each model's positional limit if very long chats misbehave.
        max_prompt_tokens = 2048 if "1.5" in model_key else 1024
        encoded = tokenizer(prompt, return_tensors="pt")
        # Truncate from the LEFT. The newest user turn and the trailing
        # "Assistant:" cue sit at the end of the prompt and must survive.
        # Truncation that keeps the beginning would drop exactly those parts
        # and keep the oldest history instead.
        inputs = {
            name: tensor[:, -max_prompt_tokens:].to(device)
            for name, tensor in encoded.items()
        }
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # A causal LM's output starts with the prompt tokens, so decode only
        # what comes after them. Searching for the prompt string in the fully
        # decoded text is unreliable. The decode round trip may not reproduce
        # the prompt verbatim, and skip_special_tokens drops markers such as
        # the reasoning trigger if they are registered as special tokens.
        reply = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # The prompt is a "User:/Assistant:" transcript, so small models often
        # keep going and invent the next user turn; keep only this reply.
        reply = reply.split("\nUser:", 1)[0].strip()
        return reply or "(Model did not produce a valid response)"
    except Exception as e:
        print(f"[!] Generation error: {e}")
        return f"Error: {str(e)}"
# ============================================================
# Gradio Interface
# ============================================================
def chat_interface(
    message: str,
    history: List[Dict],
    model_choice: str,
    temperature: float,
    max_tokens: int,
):
    """Gradio chat callback: answer ``message`` and yield the updated history.

    ``history`` is a list of ``{"role": ..., "content": ...}`` dicts (or None
    for an empty chat). The caller's list is never modified; a new list is
    yielded instead.

    Yields:
        One ``(history, "")`` tuple. The first item is the updated history;
        the empty string clears the input textbox.
    """
    # Copy so the caller's list is untouched; treat None as an empty chat.
    history = list(history or [])
    if not message or not message.strip():
        yield history, ""
        return
    # Pair each user turn with the assistant reply that follows it. Pairing
    # by role rather than by position keeps one unanswered or extra entry
    # from shifting every later pair out of alignment.
    formatted_history: List[Tuple[str, str]] = []
    pending_user = None
    for entry in history:
        role = entry.get("role")
        if role == "user":
            pending_user = entry.get("content", "")
        elif role == "assistant" and pending_user is not None:
            formatted_history.append((pending_user, entry.get("content", "")))
            pending_user = None
    response = generate_response(
        model_choice,
        message,
        formatted_history,
        max_new_tokens=max_tokens,
        temperature=temperature,
    )
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    yield history, ""
def get_title_from_first_message(message: str) -> str:
    """Return a title for a conversation that opens with ``message``.

    Empty, whitespace-only or missing messages short-circuit to
    "New Conversation" without calling the title model.
    """
    has_text = bool(message) and bool(message.strip())
    return generate_chat_title(message) if has_text else "New Conversation"
# ============================================================
# Create Gradio app
# ============================================================
def create_app():
    """Build the SupraChat Gradio Blocks app and wire up its event handlers.

    The conversation is kept in a per-session ``gr.State`` as a list of
    ``{"role": ..., "content": ...}`` dicts and mirrored into the Chatbot.
    The first message of a conversation also produces a title. Later
    messages (and empty submits) leave the displayed title unchanged.

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` app.
    """
    with gr.Blocks(title="SupraChat – SupraLabs Chat Interface") as demo:
        gr.Markdown("""
# 🤖 SupraChat
Chat interface powered by SupraLabs' ultra‑small language models.
Conversation history is stored in RAM and cleared when you leave the page.
""")
        with gr.Row():
            with gr.Column(scale=4):
                model_choice = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value=DEFAULT_MODEL,
                    label="Select Model",
                    info="Different models have different strengths",
                    # Hook for the ".model-selector" rule in the launch CSS.
                    elem_classes=["model-selector"],
                )
            with gr.Column(scale=2):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative",
                )
            with gr.Column(scale=2):
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max New Tokens",
                    info="Maximum length of the reply",
                )
        chatbot = gr.Chatbot(
            label="Conversation",
            height=500,
            # Hook for the ".chatbot-container" rule in the launch CSS.
            elem_classes=["chatbot-container"],
        )
        with gr.Row():
            msg = gr.Textbox(
                label="Message",
                placeholder="Type your message here...",
                scale=9,
                container=False,
            )
            send_btn = gr.Button("Send", scale=1, variant="primary")
        with gr.Row():
            clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary", size="sm")
            title_display = gr.Textbox(
                label="Conversation Title",
                placeholder="Auto‑generated from the first message",
                interactive=False,
                scale=1,
                # Hook for the ".title-input" rule in the launch CSS.
                elem_classes=["title-input"],
            )
        state = gr.State([])

        # ============================================================
        # Event handlers
        # ============================================================
        def respond(
            message: str,
            history: List[Dict],
            model: str,
            temp: float,
            max_tok: int,
        ):
            """Answer ``message``; return updates for (chatbot, msg, state, title)."""
            # Work on a copy: session state is replaced only via the return
            # value, never mutated in place.
            history = list(history or [])
            if not message or not message.strip():
                # Nothing to send: keep the chat and leave the title as it is.
                return history, "", history, gr.update()
            is_first_message = not history

            # Rebuild (user, assistant) pairs by role so an unanswered or
            # extra entry cannot shift later pairs out of alignment.
            formatted_history = []
            pending_user = None
            for entry in history:
                role = entry.get("role")
                if role == "user":
                    pending_user = entry.get("content", "")
                elif role == "assistant" and pending_user is not None:
                    formatted_history.append((pending_user, entry.get("content", "")))
                    pending_user = None

            response = generate_response(
                model,
                message,
                formatted_history,
                max_new_tokens=max_tok,
                temperature=temp,
            )
            history += [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response},
            ]
            # Title the conversation exactly once, from its opening message.
            # For later messages gr.update() keeps the current title displayed.
            title = (
                get_title_from_first_message(message)
                if is_first_message
                else gr.update()
            )
            return history, "", history, title

        def clear_chat():
            """Reset chatbot, input box, title and session state in one event."""
            return [], "", "New Conversation", []

        # The Send button and the Enter key trigger the same handler.
        for trigger in (send_btn.click, msg.submit):
            trigger(
                fn=respond,
                inputs=[msg, state, model_choice, temperature, max_tokens],
                outputs=[chatbot, msg, state, title_display],
            )

        clear_btn.click(
            fn=clear_chat,
            inputs=[],
            outputs=[chatbot, msg, title_display, state],
        )

        gr.Markdown("""
---
### 📋 Model Overview
| Model | Type | Description |
|-------|------|-------------|
| **Supra-50M-Instruct** | Instruct | General‑purpose chat, 50M parameters |
| **Supra-50M-Reasoning** | Reasoning | Includes a thought process for complex tasks |
| **Supra-1.5-50M-Instruct-exp** | Instruct | Experimental, 5K context window |
| **Supra-50M-Base** | Base | Raw language modelling, no instruction tuning |
| **StorySupra-10M** | Base | Specialised for story generation |
| **Supra-Mini-v5-8M** | Base | Extremely small, fast responses |
> 💡 **Note**: Conversation history is kept in memory only. It will be cleared when you reload or close the page.
""")
    return demo
# ============================================================
# Launch
# ============================================================
if __name__ == "__main__":
    # Keep the module-level name "demo" for tooling that looks it up.
    demo = create_app()
    print("App created, setting up queue...")
    # Default per-event concurrency limit for all listeners of this app.
    demo.queue(default_concurrency_limit=5)
    print("Queue set, launching...")
    soft_theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
        neutral_hue="gray",
    )
    custom_css = """
.chatbot-container {
max-width: 800px;
margin: 0 auto;
}
.model-selector {
margin-bottom: 10px;
}
.title-input {
font-size: 1.2em;
font-weight: bold;
}
"""
    # Listen on all interfaces at the conventional Spaces port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
        theme=soft_theme,
        css=custom_css,
    )