| |
|
| | """
|
| | Inference script for TouchGrass models.
|
| | Supports both 3B and 7B, CUDA and MPS backends.
|
| | """
|
| |
|
| | import argparse
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| | import torch
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer
|
| |
|
| | from configs.touchgrass_3b_config import TOUCHGRASS_3B_CONFIG
|
| | from configs.touchgrass_7b_config import TOUCHGRASS_7B_CONFIG
|
| |
|
| |
|
def parse_args():
    """Parse command-line arguments for TouchGrass inference.

    Returns:
        argparse.Namespace with model selection, runtime/backend options,
        prompt context, and sampling parameters.
    """
    parser = argparse.ArgumentParser(description="Run inference with TouchGrass model")
    # --- model & runtime -------------------------------------------------
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to trained model checkpoint",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["3b", "7b"],
        default="3b",
        help="Model size for config",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "mps", "cpu"],
        help="Device to run on",
    )
    parser.add_argument(
        "--use_mps",
        action="store_true",
        help="Use MPS backend (Apple Silicon)",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        # Fix: None was previously listed as a choice, but with type=str a
        # CLI value can never equal None, so it was unreachable; the default
        # alone represents "no quantization".
        choices=["int8", "int4"],
        default=None,
        help="Apply quantization (CUDA only)",
    )
    parser.add_argument(
        "--flash_attention",
        action="store_true",
        help="Use Flash Attention 2 (CUDA only)",
    )
    parser.add_argument(
        "--torch_compile",
        action="store_true",
        help="Use torch.compile",
    )
    # --- prompt & mode ---------------------------------------------------
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Input prompt for generation",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Run in interactive mode",
    )
    parser.add_argument(
        "--instrument",
        type=str,
        default=None,
        choices=["guitar", "piano", "drums", "vocals", "theory", "dj", "general"],
        help="Instrument context for system prompt",
    )
    parser.add_argument(
        "--skill_level",
        type=str,
        default="beginner",
        choices=["beginner", "intermediate", "advanced"],
        help="User skill level",
    )
    # --- sampling --------------------------------------------------------
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximum new tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p sampling",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.1,
        help="Repetition penalty",
    )
    return parser.parse_args()
|
| |
|
| |
|
def get_system_prompt(instrument: str, skill_level: str) -> str:
    """Assemble the system prompt for a given instrument and skill level.

    Starts from the shared Touch Grass persona, appends an instrument
    specialisation when one is defined, then appends a tone note for
    beginner or advanced users (intermediate gets no extra note).
    """
    sections = ["""You are Touch Grass 🌿, a warm, encouraging, and knowledgeable music assistant.

You help people with:
- Learning instruments (guitar, bass, piano, keys, drums, vocals)
- Understanding music theory at any level
- Writing songs (lyrics, chord progressions, structure)
- Ear training and developing musicality
- DJ skills and music production
- Genre knowledge and music history

Your personality:
- Patient and encouraging — learning music is hard and takes time
- Adapt to the learner's level automatically — simpler for beginners, deeper for advanced
- When someone is frustrated, acknowledge it warmly before helping
- Use tabs, chord diagrams, and notation when helpful
- Make learning fun, not intimidating
- Celebrate small wins

When generating tabs use this format:
[TAB]
e|---------|
B|---------|
G|---------|
D|---------|
A|---------|
E|---------|
[/TAB]

When showing chord progressions use: [PROGRESSION]I - IV - V - I[/PROGRESSION]"""]

    # Per-instrument specialisation blocks; instruments without an entry
    # (e.g. "general") add nothing.
    specialisations = {
        "guitar": "\n\nYou specialize in guitar and bass. You know:\n- All chord shapes (open, barre, power chords)\n- Tablature and fingerpicking patterns\n- Strumming and picking techniques\n- Guitar-specific theory (CAGED system, pentatonic scales)",
        "piano": "\n\nYou specialize in piano and keyboards. You know:\n- Hand position and fingerings\n- Sheet music reading\n- Scales and arpeggios\n- Chord voicings and inversions\n- Pedaling techniques",
        "drums": "\n\nYou specialize in drums and percussion. You know:\n- Drum set setup and tuning\n- Basic grooves and fills\n- Reading drum notation\n- Rhythm and timing\n- Different drumming styles",
        "vocals": "\n\nYou specialize in vocals and singing. You know:\n- Breathing techniques\n- Vocal warm-ups\n- Pitch and intonation\n- Vocal registers and range\n- Mic technique",
        "theory": "\n\nYou specialize in music theory and composition. You know:\n- Harmony and chord progressions\n- Scales and modes\n- Rhythm and time signatures\n- Song structure\n- Ear training",
        "dj": "\n\nYou specialize in DJing and production. You know:\n- Beatmatching and mixing\n- EQ and compression\n- DAW software\n- Sound design\n- Genre-specific techniques",
    }

    # Tone adjustment per skill level; "intermediate" intentionally absent.
    skill_notes = {
        "beginner": "\n\nYou are speaking to a BEGINNER. Use simple language, avoid jargon, break concepts into small steps, and be extra encouraging.",
        "advanced": "\n\nYou are speaking to an ADVANCED musician. Use technical terms freely, dive deep into nuances, and challenge them with sophisticated concepts.",
    }

    if instrument in specialisations:
        sections.append(specialisations[instrument])
    if skill_level in skill_notes:
        sections.append(skill_notes[skill_level])

    return "".join(sections)
|
| |
|
| |
|
def load_model_and_tokenizer(args):
    """Load the TouchGrass model and tokenizer for inference.

    Resolves the effective device first (falling back to CPU when CUDA is
    requested but unavailable), picks a dtype suited to that device, then
    applies the optional Flash Attention 2 / torch.compile optimizations.

    Args:
        args: Parsed CLI namespace (model_path, model_size, device,
            quantization, flash_attention, torch_compile).

    Returns:
        Tuple of (model, tokenizer), with the model in eval mode.
    """
    # NOTE(review): the size-specific config is selected but never consumed
    # by the loading path below — TODO wire it in or drop --model_size.
    if args.model_size == "3b":
        config_dict = TOUCHGRASS_3B_CONFIG
    else:
        config_dict = TOUCHGRASS_7B_CONFIG

    # Fix: resolve the effective device BEFORE loading, so we never load
    # with device_map="auto" and then immediately move the model to CPU.
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    if device == "cuda":
        # Prefer bf16 where hardware supports it; fp16 otherwise.
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        # MPS and CPU both use fp32 here.
        dtype = torch.float32

    print(f"Loading model from {args.model_path}")
    print(f"Device: {device}, Dtype: {dtype}")

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
        "device_map": "auto" if device != "cpu" else None,
    }
    # Fix: --flash_attention previously only printed a message; actually
    # request the Flash Attention 2 kernel from transformers.
    if args.flash_attention and device == "cuda":
        model_kwargs["attn_implementation"] = "flash_attention_2"
        print("Flash Attention 2 enabled")

    if args.quantization:
        # Fix: the flag was parsed but silently ignored; applying it would
        # require bitsandbytes, so warn instead of pretending it worked.
        print(f"Warning: --quantization={args.quantization} is not applied by this script")

    model = AutoModelForCausalLM.from_pretrained(args.model_path, **model_kwargs)

    if device == "cpu":
        model = model.cpu()

    if args.torch_compile and device != "mps":
        print("Using torch.compile")
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

    model.eval()

    print(f"Model loaded successfully. Vocab size: {tokenizer.vocab_size}")
    return model, tokenizer
|
| |
|
| |
|
def generate_response(
    model,
    tokenizer,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
):
    """Generate a single assistant response for *prompt*.

    Args:
        model: Causal LM supporting `.generate()`.
        tokenizer: Matching tokenizer with pad/eos tokens set.
        prompt: User message.
        system_prompt: System persona text prepended to the conversation.
        max_new_tokens: Generation budget (also reserved out of the
            4096-token context when truncating the input).
        temperature / top_p / repetition_penalty: Sampling parameters.

    Returns:
        Decoded response text, truncated at the first role-marker line if
        the model starts a new conversational turn.
    """
    # Plain-text chat template; presumably matches the format used at
    # training time — TODO confirm against the training pipeline.
    full_prompt = f"system\n{system_prompt}\nuser\n{prompt}\nassistant\n"

    # Reserve room for the generated tokens within the context window.
    inputs = tokenizer(
        full_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=4096 - max_new_tokens,
    )

    # Move tensors to wherever the (possibly device-mapped) model lives.
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens (strip the echoed prompt).
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][input_length:]

    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # Fix: the previous cleanup split on ANY occurrence of "system"/"user"/
    # "assistant" anywhere in the text, truncating legitimate answers that
    # merely contain those words. Only a marker on its own line — how the
    # template delimits turns — is treated as a turn boundary now.
    lines = response.split("\n")
    for i, line in enumerate(lines):
        if line.strip() in ("system", "user", "assistant"):
            response = "\n".join(lines[:i]).strip()
            break

    return response
|
| |
|
| |
|
def interactive_mode(model, tokenizer, args):
    """Run an interactive REPL-style chat session until the user quits."""
    system_prompt = get_system_prompt(args.instrument or "general", args.skill_level)

    rule = "=" * 60
    print("\n" + rule)
    print("Touch Grass 🌿 Interactive Mode")
    print(rule)
    print(f"Instrument: {args.instrument or 'general'}")
    print(f"Skill level: {args.skill_level}")
    print("\nType your questions. Type 'quit' or 'exit' to end.")
    print(rule + "\n")

    while True:
        # The try wraps the whole turn so Ctrl-C during generation also
        # exits cleanly and model errors don't kill the session.
        try:
            question = input("You: ").strip()
            if question.lower() in ("quit", "exit", "q"):
                print("Goodbye! Keep making music! 🎵")
                break
            if not question:
                continue

            print("\nTouch Grass: ", end="", flush=True)
            reply = generate_response(
                model,
                tokenizer,
                question,
                system_prompt,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            )
            print(reply)
            print()
        except KeyboardInterrupt:
            print("\n\nInterrupted. Goodbye!")
            break
        except Exception as e:
            print(f"\nError: {e}")
            continue
|
| |
|
| |
|
def single_prompt_mode(model, tokenizer, args):
    """Generate and print one response for ``--prompt``.

    Exits with status 1 when no prompt was supplied.
    """
    if not args.prompt:
        print("Error: --prompt is required for single prompt mode")
        sys.exit(1)

    system_prompt = get_system_prompt(args.instrument or "general", args.skill_level)

    print(f"\nPrompt: {args.prompt}\n")
    print("Generating...\n")

    sampling = {
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
    }
    reply = generate_response(model, tokenizer, args.prompt, system_prompt, **sampling)

    print(f"Touch Grass: {reply}")
|
| |
|
| |
|
def main():
    """CLI entry point: parse args, resolve device, load model, run a mode."""
    args = parse_args()

    # Fall back from CUDA when unavailable, then honor the --use_mps
    # shortcut (it overrides whatever --device resolved to).
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"
    if args.use_mps and args.device != "mps":
        args.device = "mps"

    model, tokenizer = load_model_and_tokenizer(args)

    run_mode = interactive_mode if args.interactive else single_prompt_mode
    run_mode(model, tokenizer, args)
|
| |
|
| |
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()