File size: 1,834 Bytes
a69a309
 
 
 
 
 
 
 
4a2fa0c
a69a309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import getpass
import os
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from langchain.chat_models import init_chat_model

class GraspTarget(BaseModel):
    """Grasp target position for robot"""
    # NOTE(review): this model is handed to `with_structured_output` further down in
    # this file, so the docstring above and the Field descriptions below are
    # presumably sent to the LLM as part of the output schema. Treat that text as
    # runtime data — confirm the effect on model output before rewording it.
    # TODO: adjust to the actual robot coordinate frame; is the X coordinate in meters?
    # All three coordinates are floats declared without a default value; neither
    # units nor the reference frame are specified here.
    x: float = Field(description="x-axis coordinates in three-dimensional space")
    y: float = Field(description="y-axis coordinates in three-dimensional space")
    z: float = Field(description="z-axis coordinates in three-dimensional space")

class GraspTargetGenerator:
    """Generate robot grasp targets by asking a chat LLM for structured output.

    The chat model is created with ``init_chat_model`` and bound to the
    ``GraspTarget`` schema through ``with_structured_output``.
    """

    # Environment variable that holds the Google Gemini API key.
    _API_KEY_ENV = "GOOGLE_API_KEY"
    # Provider identifier for which the Google API key is required.
    _GOOGLE_PROVIDER = "google_genai"

    def __init__(self, model_name: str = "gemini-2.0-flash", model_provider: str = "google_genai"):
        """Load environment settings and build the structured-output model.

        Args:
            model_name: Model identifier passed to ``init_chat_model``.
            model_provider: Provider identifier passed to ``init_chat_model``.

        Raises:
            ValueError: If the Google provider is selected, no API key is set
                in the environment, and an empty key is entered at the prompt.
        """
        # Load environment variables (e.g. from a .env file) before looking
        # for the API key.
        load_dotenv()

        # Make sure an API key is available, but only for the provider that
        # actually uses it; other providers must not trigger a Google prompt.
        self._ensure_api_key(model_provider)

        # Initialise the LLM and a variant that returns GraspTarget objects.
        self.llm = init_chat_model(model_name, model_provider=model_provider)
        self.structured_llm = self.llm.with_structured_output(GraspTarget)

    @classmethod
    def _ensure_api_key(cls, model_provider: str) -> None:
        """Prompt for the Google API key when the Google provider lacks one."""
        # Accept "google-genai" / "Google_GenAI" spellings as the same provider
        # (assumed to match how init_chat_model normalises names — TODO confirm).
        provider = (model_provider or "").replace("-", "_").lower()
        if provider != cls._GOOGLE_PROVIDER or os.environ.get(cls._API_KEY_ENV):
            return

        api_key = getpass.getpass("Enter API key for Google Gemini: ").strip()
        if not api_key:
            # Storing an empty key would only postpone the failure to the
            # first model call, with a far less helpful error.
            raise ValueError(f"{cls._API_KEY_ENV} is required for provider '{cls._GOOGLE_PROVIDER}'")
        os.environ[cls._API_KEY_ENV] = api_key

    def generate_grasp_target(self, task_description: str) -> list[float]:
        """Generate grasp target coordinates based on a task description.

        Args:
            task_description: Description of the task to generate coordinates for.

        Returns:
            List of [x, y, z] coordinates taken from the model's structured reply.

        Raises:
            ValueError: If the model returns no structured result.
        """
        response = self.structured_llm.invoke(f"Task: {task_description}")
        if response is None:
            # Without this guard the attribute access below would fail with an
            # opaque AttributeError on NoneType.
            raise ValueError(f"LLM returned no grasp target for task: {task_description!r}")
        return [response.x, response.y, response.z]

if __name__ == "__main__":
    # Demo run: build a generator with the default model settings and print
    # the coordinates produced for a sample task.
    demo_target = GraspTargetGenerator().generate_grasp_target("Pick up the red ball")
    print(demo_target)