"""
Local inference script for trained Quillan model
Run your multimodal AI locally via Python
"""

import torch
import os
import sys

# Make sibling modules importable when this file is run directly as a script.
sys.path.insert(0, os.path.dirname(__file__))

# NOTE(review): importing from __init__ directly is unusual — presumably this
# script lives inside the package directory; verify against the project layout.
from __init__ import QuillanSOTA, Config
from data_loader import QuillanDataset
| |
|
def load_trained_model():
    """Load the trained Quillan model from its best checkpoint.

    Returns:
        The model in eval mode, or None when no checkpoint file exists.
    """
    print("Loading trained Quillan model...")

    config = Config()
    model = QuillanSOTA(config)

    checkpoint_path = "checkpoints/quillan_best.pt"
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found: {checkpoint_path}")
        print("Please run training first to create checkpoints.")
        return None

    print(f"Loading checkpoint: {checkpoint_path}")
    # map_location='cpu' lets the checkpoint load on machines without a GPU.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints you trust (consider weights_only=True on newer torch).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("Model loaded successfully!")
    # Bug fix: the old code applied ':.4f' to checkpoint.get('loss', 'Unknown'),
    # which raises ValueError when 'loss' is missing (a str can't be
    # float-formatted). Format only when the value is actually numeric.
    loss = checkpoint.get('loss')
    loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "Unknown"
    print(f"Best loss achieved: {loss_str}")
    print(f"Trained for {checkpoint.get('epoch', 'Unknown')} epochs")

    return model
| |
|
def simple_text_inference(model, text_input, max_length=50):
    """Run a text-only forward pass with dummy image/audio/video inputs.

    Args:
        model: callable taking (text, image, audio, video) tensors and
            returning a dict of output tensors (a 'text' entry is expected).
        text_input: raw user string; encoded as clamped character ordinals.
        max_length: fixed token-sequence length (zero-padded / truncated).

    Returns:
        The decoded response string, "[Empty response]" when decoding yields
        nothing printable, or None when the model produced no 'text' output
        or inference raised.
    """
    print(f"\nProcessing: '{text_input}'")

    # Naive char-level "tokenizer": ordinals clamped to the vocab ceiling (999).
    tokens = [min(ord(c), 999) for c in text_input]
    if len(tokens) < max_length:
        tokens.extend([0] * (max_length - len(tokens)))
    else:
        tokens = tokens[:max_length]

    text_tensor = torch.tensor([tokens], dtype=torch.long)

    # Placeholder modalities so the multimodal forward signature is satisfied.
    batch_size = 1
    img = torch.randn(batch_size, 3, 256, 256)
    aud = torch.randn(batch_size, 1, 2048)
    vid = torch.randn(batch_size, 3, 8, 32, 32)

    print("Running inference...")

    with torch.no_grad():
        try:
            outputs = model(text_tensor, img, aud, vid)

            if 'text' in outputs:
                logits = outputs['text']

                # Greedy decode: highest-scoring token at every position.
                predicted_tokens = torch.argmax(logits, dim=-1)

                response_chars = []
                for token in predicted_tokens[0]:
                    token_val = token.item()
                    if token_val == 0:
                        # 0 doubles as padding / end-of-sequence: stop here.
                        break
                    elif 1 <= token_val <= 999:
                        # Clamp into printable ASCII so the output is readable.
                        char_code = min(max(token_val, 32), 126)
                        response_chars.append(chr(char_code))

                response_text = ''.join(response_chars).strip()

                if response_text:
                    print(f"Response: {response_text}")
                    return response_text
                else:
                    print("Response: [Generated empty response]")
                    return "[Empty response]"

        except Exception as e:
            # Broad catch is deliberate for this demo: report the failure
            # without killing the caller's interactive loop.
            print(f"Inference error: {e}")
            import traceback
            traceback.print_exc()
            return None
| |
|
def interactive_mode(model):
    """Run a simple REPL: read user text, answer via simple_text_inference.

    Exits on 'quit'/'exit'/'q' or Ctrl-C; any other exception is reported
    and the loop continues.
    """
    print("\n" + "=" * 50)
    print("QUILLAN LOCAL INFERENCE - READY!")
    print("=" * 50)
    print("Your multimodal AI is running locally!")
    print("Type 'quit' to exit")
    print("-" * 30)

    while True:
        try:
            user_input = input("\nYou: ").strip()

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            if not user_input:
                # Empty line: just prompt again.
                continue

            response = simple_text_inference(model, user_input)

            if response:
                print(f"Quillan: {response}")
            else:
                print("Quillan: [Processing failed]")

        except KeyboardInterrupt:
            # Ctrl-C exits cleanly instead of dumping a traceback.
            print("\nGoodbye!")
            break
        except Exception as e:
            # Keep the session alive on unexpected per-turn failures.
            print(f"Error: {e}")
| |
|
def main():
    """Entry point: load the trained checkpoint, then start interactive chat."""
    print("Starting Quillan Local Inference...")

    model = load_trained_model()
    if model is None:
        # load_trained_model already explained why (missing checkpoint).
        return

    interactive_mode(model)
| |
|
# Standard script entry guard: only launch when executed directly.
if __name__ == "__main__":
    main()
| |
|