| """ | |
| OLIFANT Text Generation Demo - External Server Version | |
| Connects to an external TiMBL server for inference, keeping the HF Space lean. | |
| """ | |
import gradio as gr
from timbl import TimblServer
from transformers import AutoTokenizer
import os
import numpy as np
from typing import List, Tuple

# Configuration - Set these via environment variables in HF Space settings
TIMBL_SERVER_HOST = os.environ.get("TIMBL_SERVER_HOST", "localhost")
TIMBL_SERVER_PORT = int(os.environ.get("TIMBL_SERVER_PORT", "7000"))
TOKENIZER_NAME = "gpt2"
CONTEXT_SIZE = int(os.environ.get("CONTEXT_SIZE", "4"))  # Must match server model
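# Example settings for a remote server (hypothetical host), set under the
# Space's "Variables and secrets":
#   TIMBL_SERVER_HOST=timbl.example.org
#   TIMBL_SERVER_PORT=7000
#   CONTEXT_SIZE=4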

class OlifantServerGenerator:
    """Wrapper for OLIFANT that generates text via an external TiMBL server"""

    def __init__(self, host: str, port: int, tokenizer_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.host = host
        self.port = port
        self.server = None
        self.model_loaded = False
        self._connect()

    def _connect(self):
        """Establish a connection to the TiMBL server"""
        try:
            self.server = TimblServer(self.host, self.port)
            self.model_loaded = True
            print(f"Connected to TiMBL server at {self.host}:{self.port}")
        except Exception as e:
            print(f"Warning: Could not connect to TiMBL server: {e}")
            self.model_loaded = False

    def _ensure_connection(self):
        """Reconnect if the connection was lost"""
        if not self.model_loaded or self.server is None:
            self._connect()

    def tokenize_context(self, text: str) -> List[str]:
        """Tokenize text with GPT-2 and return the last N token strings"""
        token_strings = self.tokenizer.tokenize(text)
        return token_strings[-CONTEXT_SIZE:] if token_strings else []

    def format_context(self, token_strings: List[str]) -> List[str]:
        """Format context for TiMBL (list of token strings, left-padded with underscores)"""
        while len(token_strings) < CONTEXT_SIZE:
            token_strings.insert(0, '_')
        return token_strings[-CONTEXT_SIZE:]
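    # Worked padding example with CONTEXT_SIZE=4 (illustrative GPT-2 tokens):
    #   tokenize_context("Once upon a")          -> ['Once', 'Ġupon', 'Ġa']
    #   format_context(['Once', 'Ġupon', 'Ġa'])  -> ['_', 'Once', 'Ġupon', 'Ġa']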
    def predict_next_token(self, text: str, temperature: float = 1.0) -> Tuple[str, float]:
        """Predict the next token for the given text via the remote server"""
        self._ensure_connection()
        if not self.model_loaded:
            return " [server not connected]", 0.0
        token_strings = self.tokenize_context(text)
        prompt_words = self.format_context(token_strings)
        try:
            classlabel, distribution, distance = self.server.classify(prompt_words)
        except Exception as e:
            print(f"Server classification error: {e}")
            self.model_loaded = False
            return " [server error]", 0.0
        if distribution and temperature > 0.0:
            tokens = list(distribution.keys())
            probs = np.array([float(distribution[t]) for t in tokens])
            if temperature != 1.0:
                # Temperature scaling in log space: p_i ~ exp(log(p_i) / T).
                # T < 1 sharpens the distribution, T > 1 flattens it.
                log_probs = np.log(probs + 1e-10)
                scaled_log_probs = log_probs / temperature
                scaled_log_probs = scaled_log_probs - np.max(scaled_log_probs)  # numerical stability
                exp_probs = np.exp(scaled_log_probs)
                probs = exp_probs / np.sum(exp_probs)
            # Renormalize defensively: np.random.choice requires p to sum to 1,
            # and the server's distribution weights may not be normalized.
            probs = probs / np.sum(probs)
            predicted_token_str = np.random.choice(tokens, p=probs)
            confidence = float(distribution[predicted_token_str])
        else:
            # Greedy decoding (temperature 0): take the server's top class label
            predicted_token_str = classlabel
            confidence = float(distribution.get(predicted_token_str, 0.0)) if distribution else 0.0
        predicted_token = self.tokenizer.convert_tokens_to_string([predicted_token_str])
        return predicted_token, confidence
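    # Behaviour sketch (hypothetical return value, assuming a loaded model):
    #   >>> generator.predict_next_token("Once upon a")
    #   (' time', 0.82)  # detokenized next token and its distribution weight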
    def generate(self, prompt: str, max_tokens: int = 50,
                 min_confidence: float = 0.0,
                 temperature: float = 1.0,
                 stop_sequences: List[str] = None) -> Tuple[str, List[dict]]:
        """Generate text autoregressively"""
        if not self.model_loaded:
            return "Server not connected - cannot generate text", []
        if stop_sequences is None:
            stop_sequences = []
        current_text = prompt
        generated_tokens = []
        token_info = []
        for i in range(max_tokens):
            next_token, confidence = self.predict_next_token(current_text, temperature=temperature)
            if "[server" in next_token:
                break
            token_info.append({
                'token_num': i + 1,
                'token': next_token,
                'confidence': confidence
            })
            current_text += next_token
            generated_tokens.append(next_token)
            # Truncate at the first stop sequence encountered and end generation
            stop_hit = False
            for stop_seq in stop_sequences:
                if stop_seq in current_text:
                    current_text = current_text.split(stop_seq)[0]
                    stop_hit = True
                    break
            if stop_hit:
                break
            # Crude repetition guard: stop when the last 5 tokens contain
            # at most 2 distinct strings
            if len(generated_tokens) >= 5:
                last_5 = generated_tokens[-5:]
                if len(set(last_5)) <= 2:
                    break
        if min_confidence > 0.0 and token_info:
            current_text, token_info = self._filter_low_confidence_sentences(
                prompt, current_text, token_info, min_confidence
            )
        return current_text, token_info
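    # Usage sketch (hypothetical prompt and settings):
    #   text, info = generator.generate("Once upon a time", max_tokens=30,
    #                                   temperature=0.7, min_confidence=0.1)
    # where `info` is a list of {'token_num', 'token', 'confidence'} dicts.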
    def _filter_low_confidence_sentences(self, prompt: str, generated_text: str,
                                         token_info: List[dict], min_confidence: float) -> Tuple[str, List[dict]]:
        """Remove sentences (after the first) that contain tokens below the confidence threshold"""
        token_confidences = {i: info['confidence'] for i, info in enumerate(token_info)}
        # Split the generated tokens into sentences at tokens containing '.'
        sentences = []
        current_sentence = ""
        current_sentence_tokens = []
        for i, token_info_item in enumerate(token_info):
            token = token_info_item['token']
            current_sentence += token
            current_sentence_tokens.append(i)
            if '.' in token:
                sentences.append({
                    'text': current_sentence,
                    'token_indices': current_sentence_tokens.copy()
                })
                current_sentence = ""
                current_sentence_tokens = []
        if current_sentence:
            sentences.append({
                'text': current_sentence,
                'token_indices': current_sentence_tokens.copy()
            })
        # Keep the first sentence unconditionally; keep later sentences only
        # while every token clears the threshold, stopping at the first failure.
        filtered_text = prompt
        filtered_token_info = []
        for i, sentence in enumerate(sentences):
            if i == 0:
                filtered_text += sentence['text']
                for token_idx in sentence['token_indices']:
                    filtered_token_info.append(token_info[token_idx])
            else:
                sentence_ok = all(
                    token_confidences[token_idx] >= min_confidence
                    for token_idx in sentence['token_indices']
                )
                if sentence_ok:
                    filtered_text += sentence['text']
                    for token_idx in sentence['token_indices']:
                        filtered_token_info.append(token_info[token_idx])
                else:
                    break
        return filtered_text, filtered_token_info
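
# Hypothetical filtering example: with min_confidence=0.3 and generation
# "It was dark. He ran far." the first sentence is always kept, while
# "He ran far." survives only if each of its tokens scored >= 0.3.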

def get_confidence_color(confidence: float) -> str:
    """Get color for confidence level"""
    conf_pct = confidence * 100
    if conf_pct >= 70:
        return "#22c55e"  # green
    elif conf_pct >= 50:
        return "#84cc16"  # lime
    elif conf_pct >= 30:
        return "#eab308"  # yellow
    elif conf_pct >= 15:
        return "#f97316"  # orange
    else:
        return "#ef4444"  # red

def format_generation_output(prompt: str, generated_text: str, token_info: List[dict]) -> str:
    """Create formatted HTML output for generation with color-coded tokens"""
    if not token_info:
        return """
        <div style="color: #ef4444; padding: 20px; background: #fee2e2; border-radius: 8px;">
            <strong>Error:</strong> No tokens generated. Check server connection.
        </div>
        """
    html = f"""
    <style>
        .token-span {{
            padding: 2px 4px;
            border-radius: 3px;
            cursor: help;
            transition: transform 0.1s ease;
            display: inline-block;
        }}
        .token-span:hover {{
            transform: scale(1.05);
            box-shadow: 0 2px 4px rgba(0,0,0,0.2);
        }}
    </style>
    <div style="font-family: system-ui, -apple-system, sans-serif;">
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                    color: white; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
            <h2 style="margin: 0; font-size: 24px; color: white;">Generated Text</h2>
            <p style="margin: 10px 0 0 0; font-size: 14px; color: rgba(255,255,255,0.9);">
                Hover over colored text to see confidence scores
            </p>
        </div>
        <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
            <div style="font-size: 16px; line-height: 1.8; color: #333;">
                <span style="background: #e0f2fe; padding: 2px 4px; border-radius: 3px;">
                    <strong>{prompt}</strong>
                </span>"""
    for info in token_info:
        token = info['token']
        confidence = info['confidence']
        conf_pct = confidence * 100
        color = get_confidence_color(confidence)
        # Escape angle brackets so tokens cannot break the HTML markup
        token_html = token.replace('<', '&lt;').replace('>', '&gt;')
        html += f"""<span class="token-span" style="background: {color}; color: white;" title="Confidence: {conf_pct:.1f}%">{token_html}</span>"""
    num_tokens = len(token_info)
    avg_confidence = sum(t['confidence'] for t in token_info) / len(token_info) * 100
    total_chars = len(generated_text)
    html += f"""
            </div>
        </div>
        <div style="background: #f0fdf4; padding: 20px; border-radius: 8px; margin-bottom: 20px;">
            <h3 style="color: #166534; margin-top: 0;">Generation Statistics</h3>
            <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">
                <div style="background: white; padding: 15px; border-radius: 6px; border-left: 3px solid #22c55e;">
                    <div style="color: #666; font-size: 13px;">Tokens Generated</div>
                    <div style="font-size: 24px; font-weight: 600; color: #166534;">{num_tokens}</div>
                </div>
                <div style="background: white; padding: 15px; border-radius: 6px; border-left: 3px solid #3b82f6;">
                    <div style="color: #666; font-size: 13px;">Avg Confidence</div>
                    <div style="font-size: 24px; font-weight: 600; color: #1e40af;">{avg_confidence:.1f}%</div>
                </div>
                <div style="background: white; padding: 15px; border-radius: 6px; border-left: 3px solid #f59e0b;">
                    <div style="color: #666; font-size: 13px;">Total Characters</div>
                    <div style="font-size: 24px; font-weight: 600; color: #92400e;">{total_chars}</div>
                </div>
            </div>
        </div>
        <div style="background: #e0f2fe; padding: 15px; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #0284c7;">
            <p style="margin: 0; color: #0369a1; font-size: 14px;">
                <strong>Server Mode:</strong> Connected to external TiMBL server
            </p>
        </div>
        <div style="background: #fef3c7; padding: 15px; border-radius: 8px; border-left: 4px solid #f59e0b;">
            <p style="margin: 0; color: #92400e; font-size: 14px;">
                <strong>Color Legend:</strong>
                <span style="background: #22c55e; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">70-100%</span>
                <span style="background: #84cc16; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">50-70%</span>
                <span style="background: #eab308; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">30-50%</span>
                <span style="background: #f97316; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">15-30%</span>
                <span style="background: #ef4444; color: white; padding: 2px 8px; border-radius: 3px; margin: 0 4px;">&lt;15%</span>
            </p>
        </div>
    </div>
    """
    return html

def generate_text(prompt: str, max_tokens: int, temperature: float, min_confidence: float,
                  generator: OlifantServerGenerator) -> str:
    """Main function called by Gradio"""
    if not prompt.strip():
        return "<p style='color: #ef4444;'>Please enter a prompt to start generation.</p>"
    try:
        generated_text, token_info = generator.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            min_confidence=min_confidence,
            temperature=temperature,
            stop_sequences=["\n\n", "<|endoftext|>"]
        )
        return format_generation_output(prompt, generated_text, token_info)
    except Exception as e:
        return f"""
        <div style="color: #ef4444; padding: 20px; background: #fee2e2; border-radius: 8px;">
            <strong>Error:</strong> {str(e)}
        </div>
        """

# Initialize generator
print(f"Connecting to TiMBL server at {TIMBL_SERVER_HOST}:{TIMBL_SERVER_PORT}...")
generator = OlifantServerGenerator(TIMBL_SERVER_HOST, TIMBL_SERVER_PORT, TOKENIZER_NAME)

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="OLIFANT Text Generation") as demo:
    gr.Markdown("""
    # OLIFANT: Autoregressive Text Generation

    **Memory-based LLM - External Server Mode**

    [Paper on arXiv](https://huggingface.co/papers/2510.22317) | [GitHub Repository](https://github.com/antalvdb/olifant)

    This Space connects to an external TiMBL server for inference, enabling larger models
    without storage constraints.

    ---
    """)

    # Server status indicator
    server_status = gr.HTML()

    def get_status_html():
        status_color = "#22c55e" if generator.model_loaded else "#ef4444"
        status_text = "Connected" if generator.model_loaded else "Disconnected"
        return f"""
        <div style="background: {status_color}20; padding: 10px 15px; border-radius: 6px; border-left: 4px solid {status_color}; margin-bottom: 20px;">
            <span style="color: {status_color}; font-weight: 600;">Server Status: {status_text}</span>
        </div>
        """

    demo.load(fn=get_status_html, outputs=server_status)

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Once upon a time",
                lines=5,
                info="Enter text to continue. The model will generate a completion."
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=50,
                    step=10,
                    label="Max Tokens",
                    info="Maximum number of tokens to generate"
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Temperature",
                    info="0 = deterministic, 1 = normal, 2 = creative"
                )
                min_confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=0.5,
                    value=0.1,
                    step=0.05,
                    label="Min Confidence",
                    info="Filter low-confidence sentences"
                )
            generate_btn = gr.Button("Generate Text", variant="primary", size="lg")

    generation_output = gr.HTML(label="Generated Text")

    gr.Examples(
        examples=[
            ["Once upon a time", 50, 0.0, 0.1],
            ["The secret to happiness is", 30, 0.7, 0.1],
            ["In the distant future,", 50, 1.2, 0.05],
            ["Scientists recently discovered that", 40, 1.0, 0.1],
        ],
        inputs=[prompt_input, max_tokens_slider, temperature_slider, min_confidence_slider],
        label="Try these examples"
    )

    generate_btn.click(
        fn=lambda p, m, t, c: generate_text(p, m, t, c, generator),
        inputs=[prompt_input, max_tokens_slider, temperature_slider, min_confidence_slider],
        outputs=generation_output
    )
    prompt_input.submit(
        fn=lambda p, m, t, c: generate_text(p, m, t, c, generator),
        inputs=[prompt_input, max_tokens_slider, temperature_slider, min_confidence_slider],
        outputs=generation_output
    )
| gr.Markdown(""" | |
| --- | |
| ### How It Works | |
| OLIFANT generates text **autoregressively**: | |
| 1. Takes your prompt and the last N tokens as context | |
| 2. Sends context to external TiMBL server | |
| 3. Server finds similar contexts using k-nearest neighbors | |
| 4. Predicts the next token based on training data | |
| 5. Samples from the probability distribution | |
| 6. Appends the predicted token and repeats | |
| ### Key Features | |
| - **CPU-Only**: No GPUs required | |
| - **Explainable**: Every prediction traceable to training examples | |
| - **External Server**: Model hosted separately, Space stays lean | |
| - **Confidence Scores**: See model certainty for each token | |
| --- | |
| <div style="text-align: center; color: #666; font-size: 14px;"> | |
| Memory-based learning for transparent AI | |
| </div> | |
| """) | |
| if __name__ == "__main__": | |
| demo.launch() | |