Spaces:
Runtime error
Runtime error
| import os | |
| from flask import Flask, request, jsonify | |
| from dotenv import load_dotenv | |
| from google.cloud import aiplatform | |
| from google.protobuf import json_format | |
| from google.protobuf.struct_pb2 import Value | |
| from prompt import create_prompt | |
# Read the local .env file first so the lookups below can see its values.
load_dotenv()

# Google Cloud project and region used for every Vertex AI call; each is
# None when the corresponding environment variable is not set.
PROJECT_ID = os.environ.get("GCP_PROJECT_ID")
LOCATION = os.environ.get("GCP_LOCATION")

# Flask application object for this module.
app = Flask(__name__)
# NOTE(review): no @app.route decorator or app.add_url_rule call is visible in
# this file, so this view does not appear to be registered with `app` -- confirm
# where (and under which URL/methods) it is meant to be exposed.
def chat_handler():
    """Handle a chat request by forwarding the user's query to a Vertex AI model.

    Reads a JSON object from the request body, builds a prompt from its
    ``"query"`` field with ``create_prompt``, and sends that prompt to the
    publisher model through the Vertex AI prediction service.

    Returns:
        A Flask JSON response (optionally paired with a status code):

        * ``{"response": <first prediction>}`` on success.
        * ``{"error": "Query not provided"}`` with status 400 when the body is
          not a JSON object or carries no non-empty ``"query"``.
        * ``{"error": <message>}`` with status 500 when the prediction call
          fails or the model returns no predictions.
    """
    # silent=True turns a missing, non-JSON or malformed body into None instead
    # of raising, so every unusable request ends in the same 400 JSON error.
    data = request.get_json(silent=True)
    # The body may legally be any JSON value (null, list, string, ...); only an
    # object can carry a "query" key.
    user_query = data.get("query") if isinstance(data, dict) else None
    if not user_query:
        return jsonify({"error": "Query not provided"}), 400

    # Create a prompt for the model
    prompt = create_prompt(user_query)

    # Initialize Vertex AI for the configured project and region
    aiplatform.init(project=PROJECT_ID, location=LOCATION)

    # The prediction client must talk to the regional API endpoint
    client_options = {"api_endpoint": f"{LOCATION}-aiplatform.googleapis.com"}
    client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)

    # NOTE(review): confirm that this publisher model accepts predict() calls
    # with a {"prompt": ...} instance; this block does not verify it.
    endpoint = (
        f"projects/{PROJECT_ID}/locations/{LOCATION}"
        "/publishers/google/models/gemini-1.0-pro-vision-001"
    )

    # Instances and parameters are both sent as protobuf Value messages
    instances = [json_format.ParseDict({"prompt": prompt}, Value())]
    parameters = json_format.ParseDict(
        {
            "temperature": 0.2,
            "maxOutputTokens": 256,
            "topP": 0.8,
            "topK": 40,
        },
        Value(),
    )

    try:
        # Send the request to the model
        response = client.predict(
            endpoint=endpoint, instances=instances, parameters=parameters
        )

        # Work on the underlying protobuf message: the wrapped prediction
        # objects are not plain dicts/lists and cannot be passed to jsonify.
        raw_predictions = type(response).pb(response).predictions
        if not raw_predictions:
            return jsonify({"error": "Model returned no predictions"}), 500

        # Convert the first prediction into plain JSON-compatible Python data
        model_response = json_format.MessageToDict(raw_predictions[0])
        return jsonify({"response": model_response})
    except Exception as e:
        # Request boundary: report any failure to the caller as a JSON 500
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Only start the server when this file is executed directly.
    # NOTE(review): debug=True enables Flask's debug mode -- confirm this
    # entry point is used for local development only.
    run_options = {"debug": True, "port": 5000}
    app.run(**run_options)