Spaces:
Sleeping
Sleeping
import asyncio
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from pydantic import BaseModel
from supabase import Client, create_client
# The FastAPI application object; the __main__ guard at the bottom of the file
# hands it to uvicorn when this module is executed directly.
app: FastAPI = FastAPI()
# Supabase client used by the history/thread helpers.
# Both settings can be overridden through the environment. The defaults are the
# project URL and the project's "anon" key (its JWT payload carries role "anon")
# that were previously hard-coded, so existing deployments keep working.
# NOTE(review): the server reads and writes conversation_history with the anon
# key. Unless row-level security restricts that table, anyone holding this
# publicly visible key can likely query every user's history directly through
# Supabase's REST API. Confirm the RLS policies, or supply a server-only key via
# SUPABASE_KEY.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "https://ykkbxlbonywjmvbyfvwo.supabase.co")
SUPABASE_KEY = os.environ.get(
    "SUPABASE_KEY",
    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6Inlra2J4bGJvbnl3am12YnlmdndvIiwicm9sZSI6ImFub24iLCJpYXQiOjE3Mzk5NTA2NjIsImV4cCI6MjA1NTUyNjY2Mn0.2BZul_igHKmZtQGhbwV3PvRsCikxviL8ogTKPD3XhuU",
)
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
class SearchQuery(BaseModel):
    """Request body for a search.

    Carries the user's query, optional extra prompt context, the user's id,
    and an optional thread to append the exchange to.
    """

    query: str
    # Optional extra context for the prompt. Declared Optional so an explicit
    # JSON null validates; with a bare ``str`` annotation Pydantic v2 rejects
    # None even though it is the default.
    context: Optional[str] = None
    user_id: str  # UUID string identifying the user whose history is read/written
    # Existing thread to append to; None means a new thread id is generated.
    thread_id: Optional[str] = None
class ConversationHistory(BaseModel):
    """One row of the ``conversation_history`` table.

    NOTE(review): this model is not referenced anywhere else in this file;
    confirm whether it is meant to be used as a ``response_model``.
    """

    id: str  # UUID as string
    user_id: str  # UUID as string
    query: str
    response: str  # JSON-encoded structured response, stored as TEXT
    timestamp: str  # ISO-8601 string as written by save_conversation
    thread_id: str  # UUID as string
    # Optional so rows whose title column is NULL still validate under Pydantic v2.
    title: Optional[str] = None
# Chat model shared by every prompt chain.
# The Groq API key comes from the GROQ_API_KEY environment variable. A key
# committed to source is readable by anyone with access to the file, so the
# previously hard-coded key should be treated as leaked and revoked.
# GROQ_MODEL optionally overrides the model name; the default is unchanged.
llm = ChatGroq(
    temperature=0.7,
    model_name=os.environ.get("GROQ_MODEL", "mixtral-8x7b-32768"),
    groq_api_key=os.environ.get("GROQ_API_KEY"),
)
# Prompt templates keyed by name; each takes ``query`` and ``context``.
# process_search always selects "general"; nothing in this file selects
# "common_threats", although a chain is still built for it.
prompt_templates = {
    # NOTE(review): "threat_1, threat_2, and threat_3" are literal text, not
    # template variables, so the model receives them verbatim. Confirm that
    # this is intended.
    "common_threats": PromptTemplate(
        input_variables=["query", "context"],
        template="""
Context: {context}
Query: {query}
Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.
"""
    ),
    # Fixed cybersecurity-expert persona followed by the caller's context.
    # The default context that process_search supplies ("You are a
    # cybersecurity expert.") repeats this persona.
    "general": PromptTemplate(
        input_variables=["query", "context"],
        template="""
Context: You are a cybersecurity expert with extensive experience in all sub-streams of the industry, including but not limited to network security, application security, cloud security, threat intelligence, penetration testing, and incident response. {context}
Query: {query}
Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context.
"""
    ),
}
def _make_chain(prompt):
    """Wrap *prompt* in an ``LLMChain`` bound to the shared ``llm``."""
    return LLMChain(llm=llm, prompt=prompt)


# Map each template name to the chain that runs it.
chains = {name: _make_chain(template) for name, template in prompt_templates.items()}
def get_conversation_history(
    user_id: str,
    thread_id: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[Dict]:
    """Fetch a user's stored exchanges from ``conversation_history``, newest first.

    Args:
        user_id: Value matched against the ``user_id`` column.
        thread_id: When truthy, rows are additionally filtered on ``thread_id``.
            Otherwise rows from all of the user's threads are returned.
        limit: Optional cap on the number of rows. ``None`` (the default)
            returns every matching row, as before.

    Returns:
        The matching rows ordered by ``timestamp`` descending. Returns ``[]``
        if the query fails; the failure is logged with its traceback.
    """
    try:
        query = supabase.table("conversation_history").select("*").eq("user_id", user_id)
        if thread_id:
            query = query.eq("thread_id", thread_id)
        query = query.order("timestamp", desc=True)
        if limit is not None:
            query = query.limit(limit)
        return query.execute().data
    except Exception:
        # Best-effort read: callers treat a failed lookup as "no history".
        logging.getLogger(__name__).exception("Error retrieving history")
        return []
def get_user_threads(user_id: str) -> List[Dict]:
    """List a user's distinct threads, most recently active first.

    Each entry is ``{"thread_id": ..., "title": ...}``. The title comes from
    the newest row of that thread. If that row has no title, a positional
    label ``"Thread N"`` (N = 1-based position in the returned list) is used.

    Returns ``[]`` if the query fails; the failure is logged with its traceback.
    """
    try:
        rows = (
            supabase.table("conversation_history")
            .select("thread_id, title")
            .eq("user_id", user_id)
            .order("timestamp", desc=True)
            .execute()
            .data
        )
        # Rows arrive newest first, so the first row seen for a thread_id is
        # that thread's latest one. Keep it and skip the rest.
        seen = set()
        threads = []
        for item in rows:
            thread_id = item["thread_id"]
            if thread_id in seen:
                continue
            seen.add(thread_id)
            threads.append({
                "thread_id": thread_id,
                "title": item.get("title") or f"Thread {len(threads) + 1}",
            })
        return threads
    except Exception:
        # Best-effort read: callers treat a failed lookup as "no threads".
        logging.getLogger(__name__).exception("Error retrieving threads")
        return []
def save_conversation(
    user_id: str,
    query: str,
    response: Dict,
    thread_id: Optional[str] = None,
    title: Optional[str] = None,
):
    """Insert one query/response exchange into ``conversation_history``.

    Args:
        user_id: Stored in the ``user_id`` column.
        query: The user's query text.
        response: Structured response; stored JSON-encoded in ``response``.
        thread_id: Thread to append to. When falsy, a new UUID4 string is generated.
        title: Row title. When falsy, defaults to the query, truncated to 50
            characters plus "..." if longer. Every row gets its own default
            title, and get_user_threads shows the newest row's title, so a
            thread's displayed title tracks its latest query.

    Returns:
        The thread id the row was stored under.

    Raises:
        Whatever the Supabase insert raises; it is logged and re-raised.
    """
    try:
        if not thread_id:
            thread_id = str(uuid.uuid4())
        if not title:
            title = query[:50] + "..." if len(query) > 50 else query
        conversation = {
            "user_id": user_id,
            "query": query,
            "response": json.dumps(response),
            # Timezone-aware UTC. datetime.utcnow() returns a naive value
            # (deprecated since Python 3.12) whose ISO string has no offset.
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "thread_id": thread_id,
            "title": title,
        }
        supabase.table("conversation_history").insert(conversation).execute()
        return thread_id
    except Exception:
        logging.getLogger(__name__).exception("Error saving conversation")
        raise
# Key under which the model's answer is stored in the structured response. It is
# also read back from stored rows when rebuilding conversation context.
_ANSWER_KEY = "Use Clear Language: Avoid ambiguity and complex wording"

# Maximum number of prior exchanges replayed into the prompt, so the context
# sent to the model does not grow without bound as a user's history accumulates.
_MAX_HISTORY_TURNS = 10


def _history_to_context(history: List[Dict]) -> str:
    """Render stored rows as "Previous Query / Previous Response" prompt text.

    ``history`` is expected newest-first, as get_conversation_history orders it.
    The output is oldest-first so the model reads exchanges in the order they
    happened. A row whose ``response`` is not a JSON object with the answer key
    contributes an empty response instead of failing the whole request.
    """
    parts = []
    for item in reversed(history):
        try:
            answer = json.loads(item["response"]).get(_ANSWER_KEY, "")
        except (KeyError, TypeError, ValueError, AttributeError):
            answer = ""
        parts.append(f"Previous Query: {item.get('query', '')}\nPrevious Response: {answer}")
    return "\n".join(parts)


async def process_search(search_query: SearchQuery):
    """Answer a cybersecurity query with the LLM and persist the exchange.

    Steps:
    1. Validates ``user_id``, and ``thread_id`` when given, as UUIDs.
    2. Builds the prompt context from the request's context (or a default
       persona) plus up to ``_MAX_HISTORY_TURNS`` recent stored exchanges.
       These come from the given thread, or from all of the user's history
       when no thread is given.
    3. Runs the "general" chain and wraps the answer in the structured dict.
    4. Saves the exchange to Supabase.

    Returns:
        ``{"status", "response", "thread_id", "classified_type"}``.

    Raises:
        HTTPException: 400 for a malformed ``user_id``/``thread_id``, and 500
        for any other failure.
    """
    # NOTE(review): no FastAPI route decorator (e.g. ``@app.post(...)``) is
    # visible on this handler, so ``app`` does not serve it as written.
    # Confirm how it is meant to be registered.

    # Validation sits outside the generic handler below. Previously the 400
    # raised here was caught by ``except Exception`` and re-raised as a 500.
    try:
        uuid.UUID(search_query.user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
    if search_query.thread_id:
        try:
            uuid.UUID(search_query.thread_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID") from None

    try:
        base_context = search_query.context or "You are a cybersecurity expert."

        # Supabase and LLM calls are synchronous. Running them in a worker
        # thread keeps them from blocking the event loop (and every other
        # request) while they wait on the network.
        history = await asyncio.to_thread(
            get_conversation_history, search_query.user_id, search_query.thread_id
        )
        history_context = _history_to_context(history[:_MAX_HISTORY_TURNS])
        full_context = f"{base_context}\n{history_context}" if history_context else base_context

        # Only the "general" template is used; there is no classification step.
        query_type = "general"
        raw_response = await asyncio.to_thread(
            chains[query_type].run, query=search_query.query, context=full_context
        )

        structured_response = {
            "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
            "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{full_context}', guiding the response to align with cybersecurity expertise.",
            _ANSWER_KEY: raw_response.strip(),
            "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity.",
        }

        # Appends to the given thread, or starts a new one when thread_id is None.
        thread_id = await asyncio.to_thread(
            save_conversation,
            user_id=search_query.user_id,
            query=search_query.query,
            response=structured_response,
            thread_id=search_query.thread_id,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "response": structured_response,
        "thread_id": thread_id,
        "classified_type": query_type,
    }
async def get_history(user_id: str, thread_id: Optional[str] = None):
    """Return a user's stored conversation rows, optionally for a single thread.

    Returns:
        ``{"status": "success", "history": [...]}``, with rows newest first.

    Raises:
        HTTPException: 400 when ``user_id`` or a given ``thread_id`` is not a
        valid UUID, and 500 on any other failure.
    """
    # NOTE(review): no FastAPI route decorator (e.g. ``@app.get(...)``) is
    # visible on this handler, so ``app`` does not serve it as written.
    # Confirm how it is meant to be registered.

    # Validation sits outside the generic handler below. Previously the 400
    # raised here was caught by ``except Exception`` and re-raised as a 500.
    try:
        uuid.UUID(user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None
    if thread_id:
        try:
            uuid.UUID(thread_id)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid thread_id: must be a valid UUID") from None

    try:
        # The synchronous Supabase call runs in a worker thread so it does not
        # block the event loop.
        history = await asyncio.to_thread(get_conversation_history, user_id, thread_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "history": history,
    }
async def get_threads(user_id: str):
    """Return the distinct conversation threads of a user.

    Returns:
        ``{"status": "success", "threads": [{"thread_id", "title"}, ...]}``.

    Raises:
        HTTPException: 400 when ``user_id`` is not a valid UUID, and 500 on any
        other failure.
    """
    # NOTE(review): no FastAPI route decorator (e.g. ``@app.get(...)``) is
    # visible on this handler, so ``app`` does not serve it as written.
    # Confirm how it is meant to be registered.

    # Validation sits outside the generic handler below. Previously the 400
    # raised here was caught by ``except Exception`` and re-raised as a 500.
    try:
        uuid.UUID(user_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid user_id: must be a valid UUID") from None

    try:
        # The synchronous Supabase call runs in a worker thread so it does not
        # block the event loop.
        threads = await asyncio.to_thread(get_user_threads, user_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "status": "success",
        "threads": threads,
    }
async def root():
    """Report that the service is up."""
    # NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible
    # here, so ``app`` does not serve this as written. Confirm registration.
    status_message = "Search API with structured response, history, and threads is running"
    return {"message": status_message}
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces. The port can be set through $PORT (hosting
    # platforms commonly assign one) and defaults to 8000 as before.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))