|
|
| """
|
| Inference script for TouchGrass models.
|
| Supports both 3B and 7B, CUDA and MPS backends.
|
| """
|
|
|
| import argparse
|
| import sys
|
| from pathlib import Path
|
|
|
| import torch
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
| from configs.touchgrass_3b_config import TOUCHGRASS_3B_CONFIG
|
| from configs.touchgrass_7b_config import TOUCHGRASS_7B_CONFIG
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for TouchGrass inference.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Passing a list makes the parser usable from
            tests and other Python code; ``None`` keeps the CLI behavior.

    Returns:
        ``argparse.Namespace`` with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Run inference with TouchGrass model")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to trained model checkpoint",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["3b", "7b"],
        default="3b",
        help="Model size for config",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "mps", "cpu"],
        help="Device to run on",
    )
    parser.add_argument(
        "--use_mps",
        action="store_true",
        help="Use MPS backend (Apple Silicon)",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        # Only values a user can actually type belong in `choices`; leaving the
        # flag off already yields None through `default` (a literal None choice
        # is unreachable from the CLI and just clutters --help).
        choices=["int8", "int4"],
        default=None,
        help="Apply quantization (CUDA only)",
    )
    parser.add_argument(
        "--flash_attention",
        action="store_true",
        help="Use Flash Attention 2 (CUDA only)",
    )
    parser.add_argument(
        "--torch_compile",
        action="store_true",
        help="Use torch.compile",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Input prompt for generation",
    )
    parser.add_argument(
        "--interactive",
        action="store_true",
        help="Run in interactive mode",
    )
    parser.add_argument(
        "--instrument",
        type=str,
        default=None,
        choices=["guitar", "piano", "drums", "vocals", "theory", "dj", "general"],
        help="Instrument context for system prompt",
    )
    parser.add_argument(
        "--skill_level",
        type=str,
        default="beginner",
        choices=["beginner", "intermediate", "advanced"],
        help="User skill level",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximum new tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p sampling",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.1,
        help="Repetition penalty",
    )
    return parser.parse_args(argv)
|
|
|
|
|
def get_system_prompt(instrument: str, skill_level: str) -> str:
    """Assemble the system prompt for an instrument and a skill level.

    The result has up to three parts:
      1. the shared Touch Grass persona text (always present);
      2. a specialization paragraph, only when ``instrument`` is one of the
         keys of ``specialties`` below (e.g. "general" adds nothing);
      3. a tone note, only for "beginner" or "advanced" (any other level,
         such as "intermediate", adds nothing).
    """
    persona = """You are Touch Grass 🌿, a warm, encouraging, and knowledgeable music assistant.

You help people with:
- Learning instruments (guitar, bass, piano, keys, drums, vocals)
- Understanding music theory at any level
- Writing songs (lyrics, chord progressions, structure)
- Ear training and developing musicality
- DJ skills and music production
- Genre knowledge and music history

Your personality:
- Patient and encouraging — learning music is hard and takes time
- Adapt to the learner's level automatically — simpler for beginners, deeper for advanced
- When someone is frustrated, acknowledge it warmly before helping
- Use tabs, chord diagrams, and notation when helpful
- Make learning fun, not intimidating
- Celebrate small wins

When generating tabs use this format:
[TAB]
e|---------|
B|---------|
G|---------|
D|---------|
A|---------|
E|---------|
[/TAB]

When showing chord progressions use: [PROGRESSION]I - IV - V - I[/PROGRESSION]"""

    # Instrument -> paragraph appended after the persona text.
    specialties = {
        "guitar": "\n\nYou specialize in guitar and bass. You know:\n- All chord shapes (open, barre, power chords)\n- Tablature and fingerpicking patterns\n- Strumming and picking techniques\n- Guitar-specific theory (CAGED system, pentatonic scales)",
        "piano": "\n\nYou specialize in piano and keyboards. You know:\n- Hand position and fingerings\n- Sheet music reading\n- Scales and arpeggios\n- Chord voicings and inversions\n- Pedaling techniques",
        "drums": "\n\nYou specialize in drums and percussion. You know:\n- Drum set setup and tuning\n- Basic grooves and fills\n- Reading drum notation\n- Rhythm and timing\n- Different drumming styles",
        "vocals": "\n\nYou specialize in vocals and singing. You know:\n- Breathing techniques\n- Vocal warm-ups\n- Pitch and intonation\n- Vocal registers and range\n- Mic technique",
        "theory": "\n\nYou specialize in music theory and composition. You know:\n- Harmony and chord progressions\n- Scales and modes\n- Rhythm and time signatures\n- Song structure\n- Ear training",
        "dj": "\n\nYou specialize in DJing and production. You know:\n- Beatmatching and mixing\n- EQ and compression\n- DAW software\n- Sound design\n- Genre-specific techniques",
    }

    if skill_level == "beginner":
        tone = "\n\nYou are speaking to a BEGINNER. Use simple language, avoid jargon, break concepts into small steps, and be extra encouraging."
    elif skill_level == "advanced":
        tone = "\n\nYou are speaking to an ADVANCED musician. Use technical terms freely, dive deep into nuances, and challenge them with sophisticated concepts."
    else:
        tone = ""

    return persona + specialties.get(instrument, "") + tone
|
|
|
|
|
def _build_quantization_config(quantization, compute_dtype):
    """Return a bitsandbytes quantization config for "int8"/"int4", or None."""
    if quantization is None:
        return None
    # Imported lazily so the non-quantized path does not depend on it;
    # loading with this config requires bitsandbytes to be installed.
    from transformers import BitsAndBytesConfig

    if quantization == "int8":
        return BitsAndBytesConfig(load_in_8bit=True)
    if quantization == "int4":
        return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype)
    raise ValueError(f"Unsupported quantization: {quantization!r}")


def load_model_and_tokenizer(args):
    """Load the tokenizer and causal LM from ``args.model_path``.

    Reads ``args.device``, ``args.quantization``, ``args.flash_attention`` and
    ``args.torch_compile``. Flash Attention 2 and quantization are applied
    only on CUDA; on other devices they are reported and ignored.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode and its weights are
        on the resolved device. The tokenizer always has a pad token (the EOS
        token is reused when none is defined).
    """
    # Resolve the effective device once so dtype, placement and the
    # optimizations below all agree (CUDA silently missing -> CPU).
    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    # Half precision only on CUDA; MPS and CPU stay in float32.
    if device == "cuda":
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32

    print(f"Loading model from {args.model_path}")
    print(f"Device: {device}, Dtype: {dtype}")

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE: the architecture comes from the checkpoint's own config, so
    # args.model_size is not needed to load the weights.
    load_kwargs = {"torch_dtype": dtype, "trust_remote_code": True}
    if device == "cuda":
        load_kwargs["device_map"] = "auto"
        if args.flash_attention:
            # Previously this flag only printed a message; it must be passed
            # to from_pretrained to actually select the attention kernel.
            load_kwargs["attn_implementation"] = "flash_attention_2"
            print("Flash Attention 2 enabled")
        quant_config = _build_quantization_config(args.quantization, dtype)
        if quant_config is not None:
            load_kwargs["quantization_config"] = quant_config
            print(f"Quantization enabled: {args.quantization}")
    elif args.flash_attention or args.quantization:
        print(f"Flash Attention / quantization require CUDA; ignoring on {device}")

    model = AutoModelForCausalLM.from_pretrained(args.model_path, **load_kwargs)

    if device != "cuda":
        # No device_map was used, so move the weights explicitly
        # (CPU, or MPS on Apple Silicon).
        model = model.to(device)

    model.eval()

    if args.torch_compile and device != "mps":
        print("Using torch.compile")
        # Compile `forward`, not the module: `generate` is looked up on the
        # original module and calls its forward, so wrapping the module with
        # torch.compile would never be exercised during generation.
        # fullgraph is left off so graph breaks in remote-code models fall
        # back to eager instead of raising.
        model.forward = torch.compile(model.forward, mode="reduce-overhead")

    print(f"Model loaded successfully. Vocab size: {tokenizer.vocab_size}")
    return model, tokenizer
|
|
|
|
|
# Role names used as turn headers in the plain-text prompt format below.
_ROLE_MARKERS = ("system", "user", "assistant")


def _truncate_at_next_turn(text: str) -> str:
    """Cut ``text`` at the first line consisting only of a role marker.

    In the prompt format each role name sits on its own line, so such a line
    means the model has started a new turn. Whole lines are matched rather
    than substrings, so ordinary words in the reply ("CAGED system",
    "every user") are kept.
    """
    kept = []
    for line in text.splitlines():
        if line.strip() in _ROLE_MARKERS:
            break
        kept.append(line)
    return "\n".join(kept).strip()


def generate_response(
    model,
    tokenizer,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    max_context_length: int = 4096,
):
    """Generate the assistant's reply to ``prompt``.

    Args:
        model: Causal LM exposing ``parameters()`` and ``generate()``.
        tokenizer: Matching tokenizer (callable, with ``decode``,
            ``pad_token_id`` and ``eos_token_id``).
        prompt: The user's message.
        system_prompt: Text placed in the system turn.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; a value <= 0 selects greedy
            decoding (sampling with a non-positive temperature is invalid).
        top_p: Nucleus-sampling threshold (only used when sampling).
        repetition_penalty: Passed through to ``generate``.
        max_context_length: Total token budget; the prompt is truncated to
            leave room for ``max_new_tokens``.

    Returns:
        The decoded reply, cut at the first role-marker line and stripped.
    """
    full_prompt = f"system\n{system_prompt}\nuser\n{prompt}\nassistant\n"

    # Keep at least one input token even if max_new_tokens >= context size,
    # since a non-positive max_length is not a valid truncation length.
    input_budget = max(1, max_context_length - max_new_tokens)
    inputs = tokenizer(
        full_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=input_budget,
    )

    device = next(model.parameters()).device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Drop the echoed prompt tokens; only the continuation is the reply.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return _truncate_at_next_turn(response)
|
|
|
|
|
def interactive_mode(model, tokenizer, args):
    """Run a read-generate-print chat loop on stdin.

    The loop ends on "quit"/"exit"/"q", on Ctrl-C, or at end of input
    (Ctrl-D or closed/piped stdin). Any other exception raised while
    handling a message is printed and the loop continues.
    """
    instrument = args.instrument or "general"
    system_prompt = get_system_prompt(instrument, args.skill_level)

    print("\n" + "="*60)
    print("Touch Grass 🌿 Interactive Mode")
    print("="*60)
    print(f"Instrument: {instrument}")
    print(f"Skill level: {args.skill_level}")
    print("\nType your questions. Type 'quit' or 'exit' to end.")
    print("="*60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
            if user_input.lower() in ["quit", "exit", "q"]:
                print("Goodbye! Keep making music! 🎵")
                break

            if not user_input:
                continue

            print("\nTouch Grass: ", end="", flush=True)
            response = generate_response(
                model,
                tokenizer,
                user_input,
                system_prompt,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            )
            print(response)
            print()

        except KeyboardInterrupt:
            print("\n\nInterrupted. Goodbye!")
            break
        except EOFError:
            # End of input: input() would raise again on every iteration, so
            # letting the generic handler below `continue` loops forever.
            print("\nGoodbye! Keep making music! 🎵")
            break
        except Exception as e:
            print(f"\nError: {e}")
            continue
|
|
|
|
|
def single_prompt_mode(model, tokenizer, args):
    """Answer ``args.prompt`` once and print the reply.

    Terminates the process with exit status 1 when ``args.prompt`` is
    missing or empty.
    """
    question = args.prompt
    if not question:
        print("Error: --prompt is required for single prompt mode")
        sys.exit(1)

    sampling = {
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
    }
    persona = get_system_prompt(args.instrument or "general", args.skill_level)

    print(f"\nPrompt: {question}\n")
    print("Generating...\n")

    reply = generate_response(model, tokenizer, question, persona, **sampling)
    print(f"Touch Grass: {reply}")
|
|
|
|
|
def main():
    """CLI entry point: parse options, pick a usable device, load, and run."""
    args = parse_args()

    # Fail fast: single-prompt mode needs --prompt. Checking here avoids
    # spending the (possibly long) model load before reporting the error.
    if not args.interactive and not args.prompt:
        print("Error: --prompt is required for single prompt mode")
        sys.exit(1)

    # --use_mps is shorthand for --device mps. Resolve it before the
    # availability checks so the fallback logic sees the final choice
    # (and does not print a misleading "falling back to CPU" first).
    if args.use_mps:
        args.device = "mps"

    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"
    elif args.device == "mps" and not torch.backends.mps.is_available():
        print("MPS not available, falling back to CPU")
        args.device = "cpu"

    model, tokenizer = load_model_and_tokenizer(args)

    if args.interactive:
        interactive_mode(model, tokenizer, args)
    else:
        single_prompt_mode(model, tokenizer, args)


if __name__ == "__main__":
    main()