# NOTE: header below is residue from the Hugging Face Spaces page this file was
# scraped from (Space owner: sam133; the Space reports "No application file").
# Commit message: "Refactor: Restructure codebase with modular design patterns
# and fix orchestrator implementation".
# commit 9529bc2
from abc import ABC, abstractmethod
import json
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..core.robot_design import RobotDesign, RobotDesignBuilder
class LLMStrategy(ABC):
    """Strategy interface for turning a free-text prompt into a RobotDesign."""

    @abstractmethod
    def generate_design(self, prompt: str) -> RobotDesign:
        """Return a RobotDesign derived from *prompt*.

        Fix: `abstractmethod` was imported but never applied, so a subclass
        that forgot to override this method would silently return None.
        The decorator makes the contract enforceable at instantiation time.
        """
class LocalLLMStrategy(LLMStrategy):
    """LLMStrategy backed by a locally loaded Hugging Face causal LM.

    Prefers the local model cache and falls back to downloading the
    weights on a cache miss. Any failure to obtain valid JSON from the
    model degrades to RobotDesignBuilder.get_default_design().
    """

    def __init__(self, model_id: str = "distilgpt2"):
        """Load tokenizer and model for *model_id* eagerly.

        Args:
            model_id: Hugging Face model identifier.
        """
        self.model_id = model_id
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self):
        """Load tokenizer/model from the local cache, downloading on a miss."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, local_files_only=True)
            self.model = AutoModelForCausalLM.from_pretrained(self.model_id, local_files_only=True)
        except Exception as e:
            # Cache miss (or corrupt cache): retry without local_files_only,
            # which lets transformers fetch the weights from the Hub.
            print(f"Failed to load model locally: {e}")
            print("Downloading model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(self.model_id)

    def generate_design(self, prompt: str) -> RobotDesign:
        """Ask the model for a JSON robot design and build a RobotDesign from it.

        Args:
            prompt: Free-text design requirements.

        Returns:
            The parsed RobotDesign, or the default design when the model
            output cannot be parsed.
        """
        formatted_prompt = f"""You are a robot design expert. Create a robot design based on the following requirements:
{prompt}
Return ONLY a JSON object with the following structure, no other text:
{{
"body": {{
"mass": 10.0,
"dimensions": [1.0, 1.0, 0.5]
}},
"wheels": [
{{
"radius": 0.2,
"width": 0.1,
"position": [0.5, 0.5, 0.2],
"friction": 0.8
}},
{{
"radius": 0.2,
"width": 0.1,
"position": [-0.5, 0.5, 0.2],
"friction": 0.8
}},
{{
"radius": 0.2,
"width": 0.1,
"position": [0.5, -0.5, 0.2],
"friction": 0.8
}},
{{
"radius": 0.2,
"width": 0.1,
"position": [-0.5, -0.5, 0.2],
"friction": 0.8
}}
],
"motors": [
{{
"max_force": 100.0,
"max_velocity": 10.0
}},
{{
"max_force": 100.0,
"max_velocity": 10.0
}},
{{
"max_force": 100.0,
"max_velocity": 10.0
}},
{{
"max_force": 100.0,
"max_velocity": 10.0
}}
]
}}"""
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt", truncation=True)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=500,
            num_return_sequences=1,
            temperature=0.3,
            do_sample=True,
            top_p=0.9,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # BUG FIX: decode only the newly generated tokens. generate() returns
        # prompt + completion; decoding outputs[0] in full meant the brace
        # search below always matched the JSON *template embedded in the
        # prompt*, never the model's actual completion.
        prompt_length = inputs["input_ids"].shape[1]
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return self._parse_response(response)

    def _parse_response(self, response: str) -> RobotDesign:
        """Extract the outermost JSON object from *response* and build a design.

        Falls back to the default design when no JSON object is found, a
        required top-level field is missing, or construction fails.
        """
        try:
            json_start = response.find('{')
            json_end = response.rfind('}') + 1
            if json_start == -1 or json_end == 0:
                return RobotDesignBuilder.get_default_design()
            design_dict = json.loads(response[json_start:json_end])
            # Reject structurally incomplete designs rather than KeyError-ing.
            for field in ("body", "wheels", "motors"):
                if field not in design_dict:
                    return RobotDesignBuilder.get_default_design()
            builder = RobotDesignBuilder()
            builder.set_body(
                design_dict["body"]["mass"],
                design_dict["body"]["dimensions"],
            )
            for wheel in design_dict["wheels"]:
                builder.add_wheel(
                    wheel["radius"],
                    wheel["width"],
                    wheel["position"],
                    wheel["friction"],
                )
            for motor in design_dict["motors"]:
                builder.add_motor(
                    motor["max_force"],
                    motor["max_velocity"],
                )
            return builder.build()
        except Exception as e:
            # Best-effort by design: malformed model output must never crash
            # the caller, so degrade to the default design.
            print(f"Failed to parse LLM response: {e}")
            return RobotDesignBuilder.get_default_design()
class LLMFactory:
    """Factory for LLMStrategy implementations."""

    @staticmethod
    def create_llm(llm_type: str = "local", model_id: Optional[str] = None) -> LLMStrategy:
        """Create an LLMStrategy for *llm_type*.

        Fix: the method takes no `self` and is invoked on the class, so it
        must be a @staticmethod — without the decorator,
        `LLMFactory().create_llm("local")` would bind the instance to
        `llm_type` and misbehave.

        Args:
            llm_type: Strategy kind; only "local" (case-insensitive) is supported.
            model_id: Optional model identifier; defaults to "distilgpt2".

        Raises:
            ValueError: If *llm_type* is not a known strategy kind.
        """
        if llm_type.lower() == "local":
            return LocalLLMStrategy(model_id or "distilgpt2")
        else:
            raise ValueError(f"Unknown LLM type: {llm_type}")