| import json |
| import base64 |
| import time |
| import httpx |
| from cryptography.hazmat.primitives import hashes |
| from cryptography.hazmat.primitives.asymmetric import padding |
| from cryptography.hazmat.primitives.serialization import load_pem_private_key |
|
|
| |
def create_jwt(client_email, private_key):
    """Build a signed RS256 JWT assertion for a service account.

    The claim set names ``client_email`` as the issuer, requests the
    ``cloud-platform`` scope, uses the Google OAuth 2.0 token endpoint as
    the audience, and expires 3600 seconds after the moment of the call.

    Args:
        client_email: Value placed in the ``iss`` claim.
        private_key: PEM-encoded private key as ``str``; it is loaded
            without a password, so it must not be encrypted.

    Returns:
        The compact ``header.payload.signature`` token as ``str``. Each
        part is base64url-encoded with trailing ``=`` padding removed.
    """

    def _b64url(raw):
        # JWT segments are base64url without the trailing "=" padding.
        return base64.urlsafe_b64encode(raw).rstrip(b'=')

    issued_at = int(time.time())

    header_json = json.dumps({"alg": "RS256", "typ": "JWT"}).encode()
    claims_json = json.dumps({
        "iss": client_email,
        "scope": "https://www.googleapis.com/auth/cloud-platform",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": issued_at + 3600,
        "iat": issued_at,
    }).encode()

    # The signature covers exactly "<header>.<claims>" in encoded form.
    signing_input = _b64url(header_json) + b'.' + _b64url(claims_json)

    signer = load_pem_private_key(private_key.encode(), password=None)
    signature = signer.sign(signing_input, padding.PKCS1v15(), hashes.SHA256())

    token = signing_input + b'.' + _b64url(signature)
    return token.decode()
|
|
def get_access_token(client_email, private_key):
    """Exchange a freshly signed JWT assertion for an OAuth 2.0 access token.

    Args:
        client_email: Service account e-mail, forwarded to ``create_jwt``.
        private_key: PEM private key string, forwarded to ``create_jwt``.

    Returns:
        The ``access_token`` field of the token endpoint's JSON response.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with an error status
            (raised by ``raise_for_status``).
        KeyError: If the response JSON has no ``access_token`` entry.
    """
    assertion = create_jwt(client_email, private_key)

    token_url = "https://oauth2.googleapis.com/token"
    form_fields = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": assertion,
    }
    form_headers = {'Content-Type': "application/x-www-form-urlencoded"}

    with httpx.Client() as http:
        token_response = http.post(
            token_url,
            data=form_fields,
            headers=form_headers,
        )
        token_response.raise_for_status()
        return token_response.json()["access_token"]
|
|
def ask_stream(prompt, client_email, private_key, project_id, engine):
    """Send one user prompt to a Vertex AI model and return the JSON reply.

    Despite the name, the request goes to the ``generateContent`` method
    and the whole response body is parsed and returned in one piece.

    The request carries a fixed system instruction, fixed generation
    settings, and two function declarations (``get_search_results`` and
    ``get_url_content``) with function calling mode ``AUTO``. Function
    calls in the reply are not executed here; the raw response is returned.

    Args:
        prompt: Text of the single ``user`` turn.
        client_email: Service account e-mail used to obtain the token.
        private_key: PEM private key string used to obtain the token.
        project_id: Project identifier inserted into the URL.
        engine: Model identifier inserted into the URL.

    Returns:
        The decoded JSON body of the response.

    Raises:
        httpx.HTTPStatusError: If either the token request or the model
            request answers with an error status.
    """

    def _single_string_arg_function(name, description, arg_name, arg_description):
        # Both declared functions take exactly one required string argument.
        return {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    arg_name: {
                        "type": "string",
                        "description": arg_description,
                    },
                },
                "required": [arg_name],
            },
        }

    user_turn = {"role": "user", "parts": [{"text": prompt}]}
    system_instruction = {
        "parts": [
            {
                "text": "You are Gemini, a large language model trained by Google. Respond conversationally",
            },
        ],
    }
    generation_config = {
        "temperature": 0.5,
        "max_output_tokens": 256,
        "top_k": 40,
        "top_p": 0.95,
    }
    function_declarations = [
        _single_string_arg_function(
            "get_search_results",
            "Search Google to enhance knowledge.",
            "prompt",
            "The prompt to search.",
        ),
        _single_string_arg_function(
            "get_url_content",
            "Get the webpage content of a URL",
            "url",
            "the URL to request",
        ),
    ]

    # Key order is kept as-is so the serialized request body is unchanged.
    request_body = {
        "contents": [user_turn],
        "system_instruction": system_instruction,
        "generationConfig": generation_config,
        "tools": [{"function_declarations": function_declarations}],
        "tool_config": {"function_calling_config": {"mode": "AUTO"}},
    }

    bearer = get_access_token(client_email, private_key)
    request_headers = {
        'Authorization': f"Bearer {bearer}",
        'Content-Type': "application/json",
    }

    method = "generateContent"
    endpoint = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{project_id}/locations/us-central1/publishers/google/models/{engine}:{method}"

    with httpx.Client() as http:
        # Generous timeout (in seconds) passed straight to httpx.
        reply = http.post(
            endpoint,
            json=request_body,
            headers=request_headers,
            timeout=600,
        )
        reply.raise_for_status()
        return reply.json()
|
|
| |
# Script body: read one question from stdin, send it to the model, and
# pretty-print the raw JSON response.
# NOTE(review): SERVICE_ACCOUNT_KEY is not defined in the visible part of this
# file; it is presumably a mapping holding the service account's
# "client_email", "private_key" and "project_id" -- confirm where it is set.
client_email, private_key, project_id = (
    SERVICE_ACCOUNT_KEY[field]
    for field in ("client_email", "private_key", "project_id")
)
engine = "gemini-1.5-pro"
# The prompt text below is Chinese for "Please enter your question: ".
user_input = input("请输入您的问题: ")
result = ask_stream(
    prompt=user_input,
    client_email=client_email,
    private_key=private_key,
    project_id=project_id,
    engine=engine,
)
# ensure_ascii=False keeps non-ASCII characters readable in the output.
print(json.dumps(result, ensure_ascii=False, indent=2))