from abc import ABC, abstractmethod
from typing import Optional
import json

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from ..core.robot_design import RobotDesign, RobotDesignBuilder


class LLMStrategy(ABC):
    """Interface for back ends that turn a text prompt into a robot design."""

    @abstractmethod
    def generate_design(self, prompt: str) -> RobotDesign:
        """Return a ``RobotDesign`` for the requirements given in *prompt*."""


class LocalLLMStrategy(LLMStrategy):
    """Generate robot designs with a causal language model run in-process.

    The tokenizer and model named by ``model_id`` are loaded eagerly in the
    constructor, so creating an instance may be slow.
    """

    # Top-level keys the parsed JSON must contain before a design is built.
    _REQUIRED_FIELDS = ("body", "wheels", "motors")

    def __init__(self, model_id: str = "distilgpt2"):
        """Store *model_id* and immediately load its tokenizer and model."""
        self.model_id = model_id
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load tokenizer and model, trying local files first.

        If the local-only load fails for any reason, both objects are loaded
        again without the ``local_files_only`` restriction.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_id, local_files_only=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_id, local_files_only=True
            )
        except Exception as e:
            # Deliberately broad: any local-load failure triggers the fallback.
            print(f"Failed to load model locally: {e}")
            print("Downloading model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(self.model_id)

    @staticmethod
    def _build_prompt(prompt: str) -> str:
        """Wrap the user's requirements in the instruction and JSON example."""
        return (
            "You are a robot design expert. \n"
            f"Create a robot design based on the following requirements: {prompt} "
            "Return ONLY a JSON object with the following structure, no other text: "
            '{ "body": { "mass": 10.0, "dimensions": [1.0, 1.0, 0.5] }, '
            '"wheels": [ '
            '{ "radius": 0.2, "width": 0.1, "position": [0.5, 0.5, 0.2], "friction": 0.8 }, '
            '{ "radius": 0.2, "width": 0.1, "position": [-0.5, 0.5, 0.2], "friction": 0.8 }, '
            '{ "radius": 0.2, "width": 0.1, "position": [0.5, -0.5, 0.2], "friction": 0.8 }, '
            '{ "radius": 0.2, "width": 0.1, "position": [-0.5, -0.5, 0.2], "friction": 0.8 } '
            '], "motors": [ '
            '{ "max_force": 100.0, "max_velocity": 10.0 }, '
            '{ "max_force": 100.0, "max_velocity": 10.0 }, '
            '{ "max_force": 100.0, "max_velocity": 10.0 }, '
            '{ "max_force": 100.0, "max_velocity": 10.0 } '
            "] }"
        )

    @staticmethod
    def _extract_json(text: str):
        """Parse the first JSON object found in *text*.

        Returns ``None`` when *text* contains no ``{``. Raises
        ``json.JSONDecodeError`` when the text from the first ``{`` onward
        does not begin with valid JSON.
        """
        start = text.find("{")
        if start == -1:
            return None
        # raw_decode stops at the end of the first complete value, so any
        # text the model emits after the object is ignored.
        parsed, _end = json.JSONDecoder().raw_decode(text, start)
        return parsed

    @staticmethod
    def _design_from_dict(design_dict) -> RobotDesign:
        """Feed the body, wheels and motors of *design_dict* to the builder.

        Raises ``KeyError``/``TypeError`` if an expected key is missing or a
        value has the wrong shape; the caller handles those.
        """
        builder = RobotDesignBuilder()
        body = design_dict["body"]
        builder.set_body(body["mass"], body["dimensions"])
        for wheel in design_dict["wheels"]:
            builder.add_wheel(
                wheel["radius"],
                wheel["width"],
                wheel["position"],
                wheel["friction"],
            )
        for motor in design_dict["motors"]:
            builder.add_motor(motor["max_force"], motor["max_velocity"])
        return builder.build()

    def generate_design(self, prompt: str) -> RobotDesign:
        """Ask the model for a design and parse its JSON answer.

        Args:
            prompt: Free-text requirements inserted into the instruction.

        Returns:
            The design built from the model's JSON output, or
            ``RobotDesignBuilder.get_default_design()`` when the output holds
            no JSON object, lacks a required top-level key, or cannot be
            parsed or built.
        """
        formatted_prompt = self._build_prompt(prompt)
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt", truncation=True)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=500,
            num_return_sequences=1,
            temperature=0.3,
            do_sample=True,
            top_p=0.9,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # A causal LM returns the prompt tokens followed by the continuation.
        # The prompt contains an example JSON object, so decoding the whole
        # sequence would make the parser pick up the example instead of the
        # model's answer. Keep only the newly generated tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        try:
            design_dict = self._extract_json(response)
            if design_dict is None:
                return RobotDesignBuilder.get_default_design()

            # Validate required fields
            if any(field not in design_dict for field in self._REQUIRED_FIELDS):
                return RobotDesignBuilder.get_default_design()

            return self._design_from_dict(design_dict)
        except Exception as e:
            # Best-effort by design: any malformed model output falls back to
            # the default design instead of propagating to the caller.
            print(f"Failed to parse LLM response: {e}")
            return RobotDesignBuilder.get_default_design()
class LLMFactory:
    """Factory that maps a back-end name to an ``LLMStrategy`` instance."""

    @staticmethod
    def create_llm(llm_type: str = "local", model_id: Optional[str] = None) -> LLMStrategy:
        """Build the strategy selected by *llm_type*.

        Args:
            llm_type: Back-end name, compared case-insensitively. Only
                ``"local"`` is recognised.
            model_id: Model identifier passed to the strategy; a falsy value
                selects ``"distilgpt2"``.

        Returns:
            A ``LocalLLMStrategy`` for the chosen model.

        Raises:
            ValueError: If *llm_type* is not a recognised back-end name.
        """
        # Guard clause: reject anything that is not the single known back end.
        if llm_type.lower() != "local":
            raise ValueError(f"Unknown LLM type: {llm_type}")

        chosen_model = model_id if model_id else "distilgpt2"
        return LocalLLMStrategy(chosen_model)