Spaces:
Sleeping
Sleeping
| # main.py | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import psycopg2 | |
| from psycopg2.extras import RealDictCursor | |
| import os | |
| from dotenv import load_dotenv | |
| import google.generativeai as genai | |
# Load variables from a .env file into the process environment first, so the
# Gemini client configured on the next line can read its API key from there.
load_dotenv()
genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))

# FastAPI application object for this module.
app = FastAPI()
class Query(BaseModel):
    # Request payload model. A `#` comment is used instead of a docstring so
    # the schema generated from this model is left exactly as it was.

    # The user's natural-language question to be translated into SQL.
    question: str
def get_gemini_response(question, prompt):
    """Send ``prompt`` and ``question`` to Gemini and return the reply text.

    Args:
        question: The user's text; sent as the second content part.
        prompt: Instruction text; sent as the first content part.

    Returns:
        The text of the model's response with leading and trailing
        whitespace removed.
    """
    gemini = genai.GenerativeModel('gemini-1.5-pro')
    # The instructions are placed ahead of the user's question.
    parts = [prompt, question]
    reply = gemini.generate_content(parts)
    # strip() removes any extra whitespace around the generated text.
    return reply.text.strip()
# Instruction text sent to the model ahead of every user question. It names
# the target dialect, lists the tables the model may use, and gives two
# worked question/answer examples. The text is part of runtime behaviour:
# any edit changes what the model is asked to do.
sql_prompt = """
Convert the following English question to a PostgreSQL query for the Pagila DVD rental database.
Only return the SQL query without any markdown formatting or explanations.
The database has these main tables:
- actor (actor_id, first_name, last_name)
- film (film_id, title, description, release_year, rental_rate, length, rating)
- category (category_id, name)
- film_category (film_id, category_id)
- inventory (inventory_id, film_id, store_id)
- rental (rental_id, rental_date, inventory_id, customer_id, return_date, staff_id)
- customer (customer_id, first_name, last_name, email)
- payment (payment_id, customer_id, staff_id, rental_id, amount, payment_date)
Example queries:
Q: List all actors
A: SELECT * FROM actor;
Q: Show top 10 most rented movies
A: SELECT f.title, COUNT(r.rental_id) as rental_count FROM film f JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id GROUP BY f.title ORDER BY rental_count DESC LIMIT 10;
"""
def execute_sql_query(query):
    """Run ``query`` against the configured PostgreSQL database.

    Connection settings are read from the ``DB_NAME``, ``DB_USER``,
    ``DB_PASSWORD``, ``DB_HOST`` and ``DB_PORT`` environment variables
    (``DB_PORT`` falls back to ``'5432'``).

    Args:
        query: SQL text to execute. In this module it is produced by the
            language model from a user's question, so it is treated as
            untrusted input.

    Returns:
        Every row of the result set, as returned by a ``RealDictCursor``
        (one mapping per row, keyed by column name).

    Raises:
        HTTPException: With status 400 and a ``"Database Error: ..."``
            detail if connecting, executing, or fetching fails.
    """
    conn = None
    try:
        conn = psycopg2.connect(
            dbname=os.getenv('DB_NAME'),
            user=os.getenv('DB_USER'),
            password=os.getenv('DB_PASSWORD'),
            host=os.getenv('DB_HOST'),
            port=os.getenv('DB_PORT', '5432')
        )
        # The SQL is model-generated, so ask for a read-only session: this
        # endpoint only needs to read, and nothing here ever commits.
        # This is defense in depth, not a sandbox -- the database role used
        # by DB_USER should still be least-privileged.
        conn.set_session(readonly=True)
        with conn.cursor(cursor_factory=RealDictCursor) as cursor:
            cursor.execute(query)
            return cursor.fetchall()
    except Exception as exc:
        # Surface any failure as a client-visible 400, keeping the original
        # exception chained for server-side tracebacks.
        raise HTTPException(status_code=400, detail=f"Database Error: {exc}") from exc
    finally:
        # Always release the connection, whether or not the query succeeded.
        if conn is not None:
            conn.close()
async def process_query(query: Query):
    """Translate a natural-language question to SQL, run it, and return both.

    Args:
        query: Request payload carrying the user's ``question``.

    Returns:
        A dict with two keys: ``"query"`` (the SQL text that was executed)
        and ``"result"`` (whatever ``execute_sql_query`` returned for it).

    Raises:
        HTTPException: With status 400 on any failure. An ``HTTPException``
            raised further down (e.g. the ``"Database Error: ..."`` one from
            ``execute_sql_query``) is propagated unchanged; any other error
            is wrapped with its message as the detail.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view of the file -- confirm it is registered on ``app``.
    # NOTE: both helpers called below are synchronous, so they run on the
    # event loop while this coroutine is executing.
    try:
        sql_query = get_gemini_response(query.question, sql_prompt)
        # Remove any SQL code block markers if present
        sql_query = sql_query.replace('```sql', '').replace('```', '').strip()
        result = execute_sql_query(sql_query)
    except HTTPException:
        # Already a well-formed HTTP error; re-wrapping it would replace its
        # detail with str() of the exception object.
        raise
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"query": sql_query, "result": result}