| """ | |
| Simple Inference Script for TinyLlama | |
| This script demonstrates how to use a fine-tuned TinyLlama model for text generation | |
| without requiring all the training dependencies. | |
| """ | |
import argparse
import json
import time
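
# Example invocations (illustrative only; the script filename "inference.py" and the
# prompts file "prompts.txt" are placeholders, not names defined elsewhere in this repo):
#   python inference.py --interactive
#   python inference.py --prompt "Explain what a tokenizer does in one paragraph."
#   python inference.py --prompt_file prompts.txt --output_file generated_outputs.json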


def parse_args():
    parser = argparse.ArgumentParser(description="Run inference with a TinyLlama model")
    parser.add_argument(
        "--model_path",
        type=str,
        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        help="Path to the model directory or HuggingFace model name",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Text prompt for generation",
    )
    parser.add_argument(
        "--prompt_file",
        type=str,
        default=None,
        help="File containing multiple prompts (one per line)",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=256,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="generated_outputs.json",
        help="File to save generated outputs",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Run in interactive mode",
    )
    return parser.parse_args()


def format_prompt_for_chat(prompt):
    """Format a prompt for chat completion."""
    return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"


def main():
    args = parse_args()

    try:
        # Import the heavy libraries here so a missing dependency produces a clear message
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError:
        print("Error: required libraries are not installed.")
        print("Please install them with: pip install torch transformers")
        return
| print(f"Loading model from {args.model_path}...") | |
| # Load model and tokenizer | |
| try: | |
| model = AutoModelForCausalLM.from_pretrained( | |
| args.model_path, | |
| torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, | |
| low_cpu_mem_usage=True | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(args.model_path) | |
| # Move model to GPU if available | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = model.to(device) | |
| model.eval() | |
| print(f"Model loaded successfully on {device}") | |
| except Exception as e: | |
| print(f"Error loading model: {e}") | |
| return | |
    if args.interactive:
        print("\n=== Interactive Mode ===")
        print("Type 'exit' or 'quit' to end the session")
        print("Type your prompts and press Enter.\n")

        while True:
            user_input = input("\nYou: ")
            if user_input.lower() in ["exit", "quit"]:
                break

            # Format prompt for chat
            formatted_prompt = format_prompt_for_chat(user_input)

            # Tokenize input (keep the attention mask alongside the input ids)
            inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

            # Generate response
            start_time = time.time()
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode response
            full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract just the assistant's response
            try:
                # This handles the case where the model properly follows the formatting
                assistant_response = full_response.split("<|im_start|>assistant\n")[1].split("<|im_end|>")[0]
            except IndexError:
                # Fallback for when the model doesn't follow the formatting perfectly
                assistant_response = full_response.replace(user_input, "").strip()

            gen_time = time.time() - start_time
            # Count only the newly generated tokens, not the prompt tokens
            new_tokens = outputs[0].shape[0] - inputs.input_ids.shape[1]
            tokens_per_second = new_tokens / gen_time

            print(f"\nAssistant: {assistant_response}")
            print(f"\n[Generated {new_tokens} tokens in {gen_time:.2f}s - {tokens_per_second:.2f} tokens/s]")
    else:
        # Collect prompts from the command line or from a file
        prompts = []
        if args.prompt:
            prompts.append(args.prompt)
        elif args.prompt_file:
            with open(args.prompt_file, "r", encoding="utf-8") as f:
                prompts = [line.strip() for line in f if line.strip()]
        else:
            print("Error: either --prompt or --prompt_file must be provided")
            return

        results = []
        print(f"Processing {len(prompts)} prompts...")

        for i, prompt in enumerate(prompts):
            print(f"Processing prompt {i + 1}/{len(prompts)}")

            # Format prompt for chat
            formatted_prompt = format_prompt_for_chat(prompt)

            # Tokenize input (keep the attention mask alongside the input ids)
            inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

            # Generate response
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode response
            full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract just the assistant's response
            try:
                assistant_response = full_response.split("<|im_start|>assistant\n")[1].split("<|im_end|>")[0]
            except IndexError:
                assistant_response = full_response.replace(prompt, "").strip()

            results.append({
                "prompt": prompt,
                "response": assistant_response,
            })

        # Save results
        with open(args.output_file, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        print(f"Generated {len(results)} responses and saved them to {args.output_file}")
| if __name__ == "__main__": | |
| main() |