# NOTE(review): the lines "Spaces:" / "Sleeping" that appeared here were
# Hugging Face Spaces page chrome captured during extraction, not source code.
| import os | |
| import time | |
| import uuid | |
| import requests | |
| import re | |
| from fastapi import FastAPI, HTTPException, Request | |
| from pydantic import BaseModel, Field | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from functools import lru_cache | |
| from typing import Optional, Dict, Any, List | |
| from dotenv import load_dotenv | |
# Load .env automatically from the project directory
load_dotenv()

# Read API key from environment; never hardcode secrets in source.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Hardcoded generation configuration
GROQ_MODEL = "moonshotai/kimi-k2-instruct-0905"  # Default Groq model
MAX_TOKENS = 2000    # Upper bound on completion length
TEMPERATURE = 0.5    # Moderate sampling randomness

# Startup diagnostics: confirm the key and config were picked up.
# NOTE(review): the status-prefix characters in these strings look like
# mojibake (emoji mangled during extraction) — confirm against the original.
if not GROQ_API_KEY:
    print("β GROQ_API_KEY is not set. Check your .env file or environment variables.")
else:
    print(f"β GROQ_API_KEY Loaded: {GROQ_API_KEY[:10]}******")  # Masked for security
    print(f"π¦ GROQ_MODEL Loaded: {GROQ_MODEL}")
    print(f"βοΈ Using parameters: MAX_TOKENS={MAX_TOKENS}, TEMPERATURE={TEMPERATURE}")
# Initialize FastAPI app
app = FastAPI(
    title="Code Generation API with Groq",
    description="API for generating code and explanations using Groq's LLM models",
    version="1.0.0"
)

# Enable CORS for frontend communication
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update this with frontend domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory conversation history keyed by session_id; each value is a list of
# OpenAI-format {"role": ..., "content": ...} dicts. Lost on restart —
# use Redis/DB for production.
conversation_history: Dict[str, List[Dict[str, str]]] = {}
| # Define request formats | |
class PromptRequest(BaseModel):
    """Request body for the generation endpoint."""
    prompt: str = Field(..., description="The user's prompt or question")
    session_id: Optional[str] = Field(None, description="Session ID for conversation history")
    response_type: Optional[str] = Field("both", description="Type of response: 'code', 'explanation', or 'both'")
class HistoryRequest(BaseModel):
    """Request body for the history retrieval/clearing endpoints."""
    session_id: str = Field(..., description="Session ID to retrieve or clear history")
def classify_message(message: str) -> str:
    """Classify *message* as "conversation" (small talk) or "code".

    Conversational phrases and short trailing-"?" questions win first;
    otherwise code keywords flip the result to "code"; anything ambiguous
    defaults to "conversation".
    """
    message_lower = message.lower().strip()

    # Common conversational greetings and phrases.
    conversational_phrases = [
        "hi", "hello", "hey", "hi there", "hello there", "hey there",
        "how are you", "good morning", "good afternoon", "good evening",
        "what's up", "how's it going", "nice to meet you", "bye", "goodbye",
        "thank you", "thanks", "ok", "okay", "yes", "no", "maybe",
        "help", "who are you", "what can you do", "what are you",
        "tell me about yourself"
    ]

    # BUG FIX: the old checks used str.startswith and plain substring
    # containment, so "highlight this algorithm" matched the greeting "hi"
    # and "this" contained "hi", misclassifying code requests as chat.
    # Match phrases on word boundaries instead.
    def _prefix(phrase: str) -> bool:
        # Phrase at the start of the message, not glued to a longer word.
        return re.match(re.escape(phrase) + r"(?!\w)", message_lower) is not None

    def _contains(phrase: str) -> bool:
        # Phrase anywhere in the message, bounded by non-word characters.
        return re.search(r"(?<!\w)" + re.escape(phrase) + r"(?!\w)", message_lower) is not None

    # Greeting prefix, greeting anywhere (first 10 phrases only, as before),
    # or a short question => conversation.
    if any(_prefix(p) for p in conversational_phrases) or \
       any(_contains(p) for p in conversational_phrases[:10]) or \
       (message_lower.endswith("?") and len(message_lower.split()) <= 8):
        return "conversation"

    # Code-related keywords push the message down the "code" path.
    code_keywords = ["code", "function", "script", "program", "algorithm", "implement",
                     "write", "create", "python", "javascript", "java", "c++"]
    if any(keyword in message_lower for keyword in code_keywords):
        return "code"

    # If in doubt, treat as conversation.
    return "conversation"
| # API call function with retry and improved error handling | |
def generate_response_groq(messages: List[Dict[str, str]]) -> str:
    """Send an OpenAI-format message list to Groq's chat-completions API.

    Retries up to 3 times with exponential backoff on rate limiting (429),
    service unavailability (503), timeouts and connection errors.

    Args:
        messages: list of {"role": ..., "content": ...} dicts.

    Returns:
        The assistant's generated text, or "No response generated" when the
        API returns no choices.

    Raises:
        HTTPException: 500 when the key is missing, 401 for a bad key, the
        upstream status for other API errors, 504 on timeout, 503 when the
        service is unreachable.
    """
    if not GROQ_API_KEY:
        raise HTTPException(status_code=500, detail="GROQ_API_KEY is missing.")
    url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": GROQ_MODEL,
        "messages": messages,
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }
    for attempt in range(3):  # Retry logic
        try:
            print(f"π Attempt {attempt + 1} - Sending request to Groq API")
            response = requests.post(url, headers=headers, json=payload, timeout=60)
            print(f"π Status Code: {response.status_code}")
            if response.status_code == 200:
                result = response.json()
                if "choices" in result and len(result["choices"]) > 0:
                    return result["choices"][0]["message"]["content"]
                return "No response generated"
            elif response.status_code == 401:  # Unauthorized (invalid API key)
                print("β Authentication error: Invalid API Key")
                raise HTTPException(status_code=401, detail="Invalid API Key. Check your GROQ_API_KEY.")
            elif response.status_code == 429:  # Rate limit error
                print("β οΈ Rate limited, retrying...")
                time.sleep(2 ** attempt)  # Exponential backoff
                continue
            elif response.status_code == 503:  # Service unavailable
                print("β οΈ Service unavailable, retrying...")
                time.sleep(2 ** attempt)
                continue
            else:
                # Pull a human-readable error message out of the body if possible.
                error_detail = "Unknown error"
                try:
                    error_data = response.json()
                    error_detail = error_data.get("error", {}).get("message", str(error_data))
                except ValueError:  # was a bare except; JSON decode errors are ValueErrors
                    error_detail = response.text
                print(f"β API Error: {error_detail}")
                if attempt == 2:  # Last attempt
                    raise HTTPException(status_code=response.status_code,
                                        detail=f"Groq API Error: {error_detail}")
        except HTTPException:
            # BUG FIX: HTTPException subclasses Exception, so the generic
            # handler below used to swallow the deliberate 401 and
            # last-attempt raises above — retrying them or re-wrapping them
            # as 500s. Propagate them unchanged instead.
            raise
        except requests.exceptions.Timeout:
            print("β οΈ Request timed out, retrying...")
            if attempt == 2:  # Last attempt
                raise HTTPException(status_code=504, detail="Request timed out")
        except requests.exceptions.ConnectionError:
            print("β οΈ Connection error, retrying...")
            if attempt == 2:  # Last attempt
                raise HTTPException(status_code=503, detail="Could not connect to Groq API")
        except Exception as e:
            print(f"β Unexpected error: {str(e)}")
            if attempt == 2:  # Last attempt
                raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
        # Wait before retry (except on last attempt)
        if attempt < 2:
            time.sleep(2 ** attempt)
    raise HTTPException(status_code=500, detail="Failed to get response after multiple attempts")
| # Helper function to process and format the model's response | |
def process_response(raw_response: str, response_type: str) -> Dict[str, Any]:
    """Shape the raw model text into the payload the API returns.

    Depending on *response_type*, the result carries the untouched response
    ("conversation"), the first fenced code block ("code"), the prose with
    code fences removed ("explanation"), or all three (default, "both").
    """
    # Shared patterns: one extracts the first fenced block's contents,
    # the other deletes every fenced block.
    fence_pattern = r"```(?:python|javascript|java|cpp|c\+\+)?\n(.*?)\n```"
    strip_pattern = r"```(?:\w+)?\n.*?\n```"

    if response_type == "conversation":
        # Conversational replies pass through untouched.
        return {"response": raw_response}

    if response_type == "code":
        match = re.search(fence_pattern, raw_response, re.DOTALL)
        # No fenced block: treat the entire reply as code.
        code = match.group(1).strip() if match else raw_response
        return {"generated_code": code}

    if response_type == "explanation":
        prose = re.sub(strip_pattern, "", raw_response, flags=re.DOTALL).strip()
        return {"explanation": prose}

    # Default ("both"): raw text plus extracted code and code-free prose.
    match = re.search(fence_pattern, raw_response, re.DOTALL)
    code = None
    prose = raw_response
    if match:
        code = match.group(1).strip()
        prose = re.sub(strip_pattern, "", raw_response, flags=re.DOTALL).strip()
    return {
        "response": raw_response,
        "generated_code": code,
        "explanation": prose
    }
| # API route for generating responses | |
# NOTE(review): the route decorator appears to have been lost in extraction —
# without one this handler is never registered on the app. Confirm the exact
# path the frontend calls before deploying.
@app.post("/generate")
async def generate_response(request: PromptRequest):
    """Handle a user prompt: classify it, query Groq, and track session history.

    Returns a dict whose keys depend on classification: conversational
    replies carry {"response", "message_type"}; code-related replies carry
    whatever process_response() builds plus "message_type". A "session_id"
    key is always included so the client can continue the conversation.

    Raises:
        HTTPException: propagated from generate_response_groq, or 500 on
        any other failure.
    """
    try:
        # Reuse the caller's session or mint a fresh one.
        session_id = request.session_id or str(uuid.uuid4())
        if session_id not in conversation_history:
            conversation_history[session_id] = []

        # Decide up front whether this is small talk or a coding request.
        message_type = classify_message(request.prompt)

        # Build the OpenAI-format message array for the Groq API,
        # starting with a system prompt matched to the request.
        messages = []
        if message_type == "conversation":
            system_prompt = "You are a helpful and friendly AI assistant. Engage in natural conversation and answer questions clearly."
        else:
            if request.response_type == "code":
                system_prompt = "You are an expert programmer. Provide clean, efficient code solutions. Always wrap code in markdown code blocks with the appropriate language tag."
            elif request.response_type == "explanation":
                system_prompt = "You are a programming tutor. Explain programming concepts clearly without providing code. Focus on the approach and logic."
            else:  # both
                system_prompt = "You are an expert programmer and teacher. Provide clear explanations followed by well-commented code examples. Always wrap code in markdown code blocks."
        messages.append({"role": "system", "content": system_prompt})

        # Include the last 6 stored messages to keep context manageable.
        messages.extend(conversation_history[session_id][-6:])

        # Current user turn goes last.
        messages.append({"role": "user", "content": request.prompt})

        print(f"π€ Sending {len(messages)} messages to Groq...")
        generated_response = generate_response_groq(messages)
        print(f"β Received response of length: {len(generated_response)}")

        # Persist the exchange in OpenAI message format.
        conversation_history[session_id].append({"role": "user", "content": request.prompt})
        conversation_history[session_id].append({"role": "assistant", "content": generated_response})

        # Cap history at 20 messages (10 exchanges) to bound memory use.
        if len(conversation_history[session_id]) > 20:
            conversation_history[session_id] = conversation_history[session_id][-20:]

        if message_type == "conversation":
            # Small talk: return the raw reply without code extraction.
            response_data = {
                "response": generated_response,
                "message_type": "conversation"
            }
        else:
            # Code-related: split into code/explanation per the requested type.
            response_data = process_response(generated_response, request.response_type)
            response_data["message_type"] = "code"

        response_data["session_id"] = session_id
        return response_data
    except HTTPException:
        # Re-raise HTTP exceptions to preserve their status codes
        # (dropped the unused "as e" binding).
        raise
    except Exception as e:
        print(f"β Unexpected error in generate_response: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
| # API route for clearing conversation history | |
# NOTE(review): route decorator appears lost in extraction — without it this
# handler is never registered. Confirm the path the frontend calls.
@app.post("/clear_history")
async def clear_history(request: HistoryRequest):
    """Clear the stored conversation history for one session.

    Returns a status dict; "not_found" when the session_id is unknown.
    """
    if request.session_id in conversation_history:
        conversation_history[request.session_id] = []
        return {"status": "success", "message": "Conversation history cleared"}
    return {"status": "not_found", "message": "Session ID not found"}
| # API route for getting conversation history | |
# NOTE(review): route decorator appears lost in extraction — without it this
# handler is never registered. Confirm the path the frontend calls.
@app.post("/get_history")
async def get_history(request: HistoryRequest):
    """Return the stored conversation history for one session.

    Returns a status dict with the message list; "not_found" when the
    session_id is unknown.
    """
    if request.session_id in conversation_history:
        return {
            "status": "success",
            "history": conversation_history[request.session_id]
        }
    return {"status": "not_found", "message": "Session ID not found"}
| # Health check endpoint | |
# NOTE(review): route decorator appears lost in extraction — without it this
# endpoint is unreachable. Confirm the path (and GET method) before deploying.
@app.get("/health")
async def health_check():
    """Report service liveness plus the configured model and version."""
    return {
        "status": "ok",
        "service": "Groq Code Generation API",
        "model": GROQ_MODEL,
        "version": "1.0.0"
    }
| # Request logging middleware for debugging | |
# NOTE(review): the middleware decorator appears lost in extraction — without
# it this hook never runs. Confirm against the original file.
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log method, path, status code and wall-clock latency for every request."""
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    print(f"π {request.method} {request.url.path} β Status: {response.status_code} ({process_time:.2f}s)")
    return response
if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("PORT", "7860"))  # Hugging Face Spaces uses port 7860
    # Bind to all interfaces so the container's port mapping works.
    uvicorn.run(app, host="0.0.0.0", port=port)