Antal van den Bosch
Remove packages.txt entirely - not needed for HTTP client
d0ff18b unverified
"""
OLIFANT Text Generation Demo - External Server Version
Connects to an external TiMBL server for inference, keeping the HF Space lean.
"""
import gradio as gr
from timbl import TimblServer
from transformers import AutoTokenizer
import os
import numpy as np
from typing import List, Tuple, Dict
# Configuration - Set these via environment variables in HF Space settings
TIMBL_SERVER_HOST = os.environ.get("TIMBL_SERVER_HOST", "localhost")
TIMBL_SERVER_PORT = int(os.environ.get("TIMBL_SERVER_PORT", "7000"))
TOKENIZER_NAME = "gpt2"
CONTEXT_SIZE = int(os.environ.get("CONTEXT_SIZE", "4")) # Must match server model
class OlifantServerGenerator:
"""Wrapper for OLIFANT that generates text via external TiMBL server"""
def __init__(self, host: str, port: int, tokenizer_name: str):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.host = host
self.port = port
self.server = None
self.model_loaded = False
self._connect()
def _connect(self):
"""Establish connection to TiMBL server"""
try:
self.server = TimblServer(self.host, self.port)
self.model_loaded = True
print(f"Connected to TiMBL server at {self.host}:{self.port}")
except Exception as e:
print(f"Warning: Could not connect to TiMBL server: {e}")
self.model_loaded = False
def _ensure_connection(self):
"""Reconnect if connection was lost"""
if not self.model_loaded or self.server is None:
self._connect()
def tokenize_context(self, text: str) -> List[str]:
"""Tokenize text with GPT-2 and return last N token strings"""
token_strings = self.tokenizer.tokenize(text)
return token_strings[-CONTEXT_SIZE:] if token_strings else []
def format_context(self, token_strings: List[str]) -> List[str]:
"""Format context for TiMBL (list of token strings, padded with underscores)"""
while len(token_strings) < CONTEXT_SIZE:
token_strings.insert(0, '_')
return token_strings[-CONTEXT_SIZE:]
def predict_next_token(self, text: str, temperature: float = 1.0) -> Tuple[str, float]:
"""Predict next token for given text via remote server"""
self._ensure_connection()
if not self.model_loaded:
return " [server not connected]", 0.0
token_strings = self.tokenize_context(text)
prompt_words = self.format_context(token_strings)
try:
classlabel, distribution, distance = self.server.classify(prompt_words)
except Exception as e:
print(f"Server classification error: {e}")
self.model_loaded = False
return " [server error]", 0.0
if distribution and temperature > 0.0:
tokens = list(distribution.keys())
probs = np.array([float(distribution[t]) for t in tokens])
if temperature != 1.0:
log_probs = np.log(probs + 1e-10)
scaled_log_probs = log_probs / temperature
scaled_log_probs = scaled_log_probs - np.max(scaled_log_probs)
exp_probs = np.exp(scaled_log_probs)
probs = exp_probs / np.sum(exp_probs)
predicted_token_str = np.random.choice(tokens, p=probs)
confidence = float(distribution[predicted_token_str])
else:
predicted_token_str = classlabel
confidence = float(distribution.get(predicted_token_str, 0.0)) if distribution else 0.0
predicted_token = self.tokenizer.convert_tokens_to_string([predicted_token_str])
return predicted_token, confidence
def generate(self, prompt: str, max_tokens: int = 50,
min_confidence: float = 0.0,
temperature: float = 1.0,
stop_sequences: List[str] = None) -> Tuple[str, List[dict]]:
"""Generate text autoregressively"""
if not self.model_loaded:
return "Server not connected - cannot generate text", []
if stop_sequences is None:
stop_sequences = []
current_text = prompt
generated_tokens = []
token_info = []
for i in range(max_tokens):
next_token, confidence = self.predict_next_token(current_text, temperature=temperature)
if "[server" in next_token:
break
token_info.append({
'token_num': i + 1,
'token': next_token,
'confidence': confidence
})
current_text += next_token
generated_tokens.append(next_token)
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in current_text:
current_text = current_text.split(stop_seq)[0]
break
else:
continue
break
if len(generated_tokens) >= 5:
last_5 = generated_tokens[-5:]
if len(set(last_5)) <= 2:
break
if min_confidence > 0.0 and token_info:
current_text, token_info = self._filter_low_confidence_sentences(
prompt, current_text, token_info, min_confidence
)
return current_text, token_info
def _filter_low_confidence_sentences(self, prompt: str, generated_text: str,
token_info: List[dict], min_confidence: float) -> Tuple[str, List[dict]]:
"""Remove sentences (after the first) that contain tokens below confidence threshold"""
token_confidences = {i: info['confidence'] for i, info in enumerate(token_info)}
sentences = []
current_sentence = ""
current_sentence_tokens = []
for i, token_info_item in enumerate(token_info):
token = token_info_item['token']
current_sentence += token
current_sentence_tokens.append(i)
if '.' in token:
sentences.append({
'text': current_sentence,
'token_indices': current_sentence_tokens.copy()
})
current_sentence = ""
current_sentence_tokens = []
if current_sentence:
sentences.append({
'text': current_sentence,
'token_indices': current_sentence_tokens.copy()
})
filtered_text = prompt
filtered_token_info = []
for i, sentence in enumerate(sentences):
if i == 0:
filtered_text += sentence['text']
for token_idx in sentence['token_indices']:
filtered_token_info.append(token_info[token_idx])
else:
sentence_ok = all(
token_confidences[token_idx] >= min_confidence
for token_idx in sentence['token_indices']
)
if sentence_ok:
filtered_text += sentence['text']
for token_idx in sentence['token_indices']:
filtered_token_info.append(token_info[token_idx])
else:
break
return filtered_text, filtered_token_info
def get_confidence_color(confidence: float) -> str:
"""Get color for confidence level"""
conf_pct = confidence * 100
if conf_pct >= 70:
return "#22c55e"
elif conf_pct >= 50:
return "#84cc16"
elif conf_pct >= 30:
return "#eab308"
elif conf_pct >= 15:
return "#f97316"
else:
return "#ef4444"
def format_generation_output(prompt: str, generated_text: str, token_info: List[dict]) -> str:
"""Create formatted HTML output for generation with color-coded tokens"""
if not token_info:
return f"""
<div style="color: #ef4444; padding: 20px; background: #fee2e2; border-radius: 8px;">
<strong>Error:</strong> No tokens generated. Check server connection.
</div>
"""
html = f"""
<style>
.token-span {{
padding: 2px 4px;
border-radius: 3px;
cursor: help;
transition: transform 0.1s ease;
display: inline-block;
}}
.token-span:hover {{
transform: scale(1.05);
box-shadow: 0 2px 4px rgba(0,0,0,0.2);
}}
</style>
<div style="font-family: system-ui, -apple-system, sans-serif;">
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
<h2 style="margin: 0; font-size: 24px; color: white;">Generated Text</h2>
<p style="margin: 10px 0 0 0; font-size: 14px; color: rgba(255,255,255,0.9);">
Hover over colored text to see confidence scores
</p>
</div>
<div style="background: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
<div style="font-size: 16px; line-height: 1.8; color: #333;">
<span style="background: #e0f2fe; padding: 2px 4px; border-radius: 3px;">
<strong>{prompt}</strong>
</span>"""
for info in token_info:
token = info['token']
confidence = info['confidence']
conf_pct = confidence * 100
color = get_confidence_color(confidence)
token_html = token.replace('<', '&lt;').replace('>', '&gt;')
html += f"""<span class="token-span" style="background: {color}; color: white;" title="Confidence: {conf_pct:.1f}%">{token_html}</span>"""
num_tokens = len(token_info)
avg_confidence = sum(t['confidence'] for t in token_info) / len(token_info) * 100
total_chars = len(generated_text)
html += f"""
</div>
</div>
<div style="background: #f0fdf4; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
<h3 style="color: #166534; margin-top: 0;">Generation Statistics</h3>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">
<div style="background: white; padding: 15px; border-radius: 6px; border-left: 3px solid #22c55e;">
<div style="color: #666; font-size: 13px;">Tokens Generated</div>
<div style="font-size: 24px; font-weight: 600; color: #166534;">{num_tokens}</div>
</div>
<div style="background: white; padding: 15px; border-radius: 6px; border-left: 3px solid #3b82f6;">
<div style="color: #666; font-size: 13px;">Avg Confidence</div>
<div style="font-size: 24px; font-weight: 600; color: #1e40af;">{avg_confidence:.1f}%</div>
</div>
<div style="background: white; padding: 15px; border-radius: 6px; border-left: 3px solid #f59e0b;">
<div style="color: #666; font-size: 13px;">Total Characters</div>
<div style="font-size: 24px; font-weight: 600; color: #92400e;">{total_chars}</div>
</div>
</div>
</div>
<div style="background: #e0f2fe; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #0284c7;">
<p style="margin: 0; color: #0369a1; font-size: 14px;">
<strong>Server Mode:</strong> Connected to external TiMBL server
</p>
</div>
<div style="background: #fef3c7; padding: 15px; border-radius: 8px; border-left: 4px solid #f59e0b;">
<p style="margin: 0; color: #92400e; font-size: 14px;">
<strong>Color Legend:</strong>
<span style="background: #22c55e; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">70-100%</span>
<span style="background: #84cc16; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">50-70%</span>
<span style="background: #eab308; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">30-50%</span>
<span style="background: #f97316; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">15-30%</span>
<span style="background: #ef4444; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">&lt;15%</span>
</p>
</div>
</div>
"""
return html
def generate_text(prompt: str, max_tokens: int, temperature: float, min_confidence: float,
generator: OlifantServerGenerator) -> str:
"""Main function called by Gradio"""
if not prompt.strip():
return "<p style='color: #ef4444;'>Please enter a prompt to start generation.</p>"
try:
generated_text, token_info = generator.generate(
prompt=prompt,
max_tokens=max_tokens,
min_confidence=min_confidence,
temperature=temperature,
stop_sequences=["\n\n", "<|endoftext|>"]
)
return format_generation_output(prompt, generated_text, token_info)
except Exception as e:
return f"""
<div style="color: #ef4444; padding: 20px; background: #fee2e2; border-radius: 8px;">
<strong>Error:</strong> {str(e)}
</div>
"""
# Initialize generator
print(f"Connecting to TiMBL server at {TIMBL_SERVER_HOST}:{TIMBL_SERVER_PORT}...")
generator = OlifantServerGenerator(TIMBL_SERVER_HOST, TIMBL_SERVER_PORT, TOKENIZER_NAME)
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="OLIFANT Text Generation") as demo:
gr.Markdown("""
# OLIFANT: Autoregressive Text Generation
**Memory-based LLM - External Server Mode**
[Paper on arXiv](https://huggingface.co/papers/2510.22317) | [GitHub Repository](https://github.com/antalvdb/olifant)
This Space connects to an external TiMBL server for inference, enabling larger models
without storage constraints.
---
""")
# Server status indicator
server_status = gr.HTML()
def get_status_html():
status_color = "#22c55e" if generator.model_loaded else "#ef4444"
status_text = "Connected" if generator.model_loaded else "Disconnected"
return f"""
<div style="background: {status_color}20; padding: 10px 15px; border-radius: 6px; border-left: 4px solid {status_color}; margin-bottom: 20px;">
<span style="color: {status_color}; font-weight: 600;">Server Status: {status_text}</span>
</div>
"""
demo.load(fn=get_status_html, outputs=server_status)
with gr.Row():
with gr.Column(scale=2):
prompt_input = gr.Textbox(
label="Enter your prompt",
placeholder="Once upon a time",
lines=5,
info="Enter text to continue. The model will generate a completion."
)
with gr.Row():
max_tokens_slider = gr.Slider(
minimum=10,
maximum=200,
value=50,
step=10,
label="Max Tokens",
info="Maximum number of tokens to generate"
)
temperature_slider = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.0,
step=0.1,
label="Temperature",
info="0 = deterministic, 1 = normal, 2 = creative"
)
min_confidence_slider = gr.Slider(
minimum=0.0,
maximum=0.5,
value=0.1,
step=0.05,
label="Min Confidence",
info="Filter low-confidence sentences"
)
generate_btn = gr.Button("Generate Text", variant="primary", size="lg")
generation_output = gr.HTML(label="Generated Text")
gr.Examples(
examples=[
["Once upon a time", 50, 0.0, 0.1],
["The secret to happiness is", 30, 0.7, 0.1],
["In the distant future,", 50, 1.2, 0.05],
["Scientists recently discovered that", 40, 1.0, 0.1],
],
inputs=[prompt_input, max_tokens_slider, temperature_slider, min_confidence_slider],
label="Try these examples"
)
generate_btn.click(
fn=lambda p, m, t, c: generate_text(p, m, t, c, generator),
inputs=[prompt_input, max_tokens_slider, temperature_slider, min_confidence_slider],
outputs=generation_output
)
prompt_input.submit(
fn=lambda p, m, t, c: generate_text(p, m, t, c, generator),
inputs=[prompt_input, max_tokens_slider, temperature_slider, min_confidence_slider],
outputs=generation_output
)
gr.Markdown("""
---
### How It Works
OLIFANT generates text **autoregressively**:
1. Takes your prompt and the last N tokens as context
2. Sends context to external TiMBL server
3. Server finds similar contexts using k-nearest neighbors
4. Predicts the next token based on training data
5. Samples from the probability distribution
6. Appends the predicted token and repeats
### Key Features
- **CPU-Only**: No GPUs required
- **Explainable**: Every prediction traceable to training examples
- **External Server**: Model hosted separately, Space stays lean
- **Confidence Scores**: See model certainty for each token
---
<div style="text-align: center; color: #666; font-size: 14px;">
Memory-based learning for transparent AI
</div>
""")
if __name__ == "__main__":
demo.launch()