Spaces:
Build error
Build error
| import getpass | |
| import os | |
| from dotenv import load_dotenv | |
| from pydantic import BaseModel, Field | |
| from langchain.chat_models import init_chat_model | |
class GraspTarget(BaseModel):
    """Grasp target position for robot"""

    # TODO: adjust these fields to the real robot coordinate frame; the unit
    # is still unconfirmed (is the X coordinate in meters?).
    # NOTE(review): this model is passed to `with_structured_output`, so the
    # class docstring and the Field descriptions are presumably forwarded to
    # the LLM as schema text -- confirm before rewording any of them.
    x: float = Field(
        description="x-axis coordinates in three-dimensional space",
    )
    y: float = Field(
        description="y-axis coordinates in three-dimensional space",
    )
    z: float = Field(
        description="z-axis coordinates in three-dimensional space",
    )
class GraspTargetGenerator:
    """Generate robot grasp targets by querying a chat model.

    The chat model is wrapped with ``with_structured_output(GraspTarget)``,
    and each request reads the ``x``, ``y`` and ``z`` attributes of the
    object the wrapped model returns.
    """

    def __init__(self, model_name: str = "gemini-2.0-flash", model_provider: str = "google_genai"):
        """Load environment settings and build the structured-output model.

        Args:
            model_name: Model identifier forwarded to ``init_chat_model``.
            model_provider: Provider identifier forwarded to
                ``init_chat_model``.
        """
        # Load environment variables before looking for the API key.
        load_dotenv()

        # Make sure an API key is present. GOOGLE_API_KEY only matters for the
        # Google Gemini provider, so callers selecting another provider are
        # not blocked by an irrelevant interactive prompt.
        if model_provider == "google_genai" and not os.environ.get("GOOGLE_API_KEY"):
            os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

        # Initialise the LLM and its structured-output wrapper.
        self.llm = init_chat_model(model_name, model_provider=model_provider)
        self.structured_llm = self.llm.with_structured_output(GraspTarget)

    def generate_grasp_target(self, task_description: str) -> list[float]:
        """Generate grasp target coordinates based on a task description.

        Args:
            task_description: Description of the task to generate
                coordinates for.

        Returns:
            List of [x, y, z] coordinates.

        Raises:
            ValueError: If the model call yields no structured result.
        """
        response = self.structured_llm.invoke(f"Task: {task_description}")

        # Fail with an explicit message instead of an opaque AttributeError
        # when the model did not produce a structured answer.
        if response is None:
            raise ValueError(
                "The model returned no structured grasp target for task: "
                f"{task_description!r}"
            )

        return [response.x, response.y, response.z]
if __name__ == "__main__":
    # Manual smoke run: build a generator with its default model settings,
    # request a target for a sample task, and print the returned coordinates.
    demo_generator = GraspTargetGenerator()
    demo_target = demo_generator.generate_grasp_target("Pick up the red ball")
    print(demo_target)