File size: 13,943 Bytes
59da258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
import os
import time
import uuid
import requests
import re
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from fastapi.middleware.cors import CORSMiddleware
from functools import lru_cache
from typing import Optional, Dict, Any, List
from dotenv import load_dotenv

# Load .env automatically from the project directory
load_dotenv()

# Read API key from environment
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Hardcoded configuration
GROQ_MODEL = "moonshotai/kimi-k2-instruct-0905"  # Default Groq model
MAX_TOKENS = 2000  # upper bound on tokens per completion (sent as max_tokens)
TEMPERATURE = 0.5  # sampling temperature forwarded to the Groq API

# Debugging: Check if API key is loaded
if not GROQ_API_KEY:
    print("❌ GROQ_API_KEY is not set. Check your .env file or environment variables.")
else:
    # NOTE(review): this writes the first 10 characters of the key to
    # stdout/logs — consider masking more aggressively in production.
    print(f"βœ… GROQ_API_KEY Loaded: {GROQ_API_KEY[:10]}******")  # Masked for security

print(f"πŸ“¦ GROQ_MODEL Loaded: {GROQ_MODEL}")
print(f"βš™οΈ Using parameters: MAX_TOKENS={MAX_TOKENS}, TEMPERATURE={TEMPERATURE}")

# Initialize FastAPI app
app = FastAPI(
    title="Code Generation API with Groq",
    description="API for generating code and explanations using Groq's LLM models",
    version="1.0.0"
)

# Enable CORS for frontend communication
# NOTE(review): browsers reject wildcard origins combined with
# allow_credentials=True; list explicit frontend origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update this with frontend domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory conversation history keyed by session_id, holding OpenAI-format
# {"role", "content"} messages (use Redis/DB for production — a restart
# loses all sessions, and this is not shared across worker processes)
conversation_history: Dict[str, List[Dict[str, str]]] = {}


# Define request formats
class PromptRequest(BaseModel):
    """Request body for POST /generate/: the prompt plus session/format options."""

    prompt: str = Field(..., description="The user's prompt or question")
    session_id: Optional[str] = Field(None, description="Session ID for conversation history")
    response_type: Optional[str] = Field("both", description="Type of response: 'code', 'explanation', or 'both'")


class HistoryRequest(BaseModel):
    """Request body for the history endpoints: identifies the session to act on."""

    session_id: str = Field(..., description="Session ID to retrieve or clear history")


def classify_message(message: str) -> str:
    """Classify a user message as 'conversation' or 'code'.

    Uses whole-word (regex word-boundary) matching instead of raw substring
    checks: the original logic matched "hi" inside words like "this" or
    "highlight", misrouting code requests such as "fix this code" to the
    conversational path.

    Args:
        message: Raw user prompt text.

    Returns:
        "conversation" for greetings, small talk and short questions,
        "code" when programming-related keywords are present, and
        "conversation" as the fallback when in doubt.
    """
    message_lower = message.lower().strip()

    # Common conversational greetings and phrases; the first 10 are the
    # greetings also checked anywhere inside the message.
    conversational_phrases = [
        "hi", "hello", "hey", "hi there", "hello there", "hey there",
        "how are you", "good morning", "good afternoon", "good evening",
        "what's up", "how's it going", "nice to meet you", "bye", "goodbye",
        "thank you", "thanks", "ok", "okay", "yes", "no", "maybe",
        "help", "who are you", "what can you do", "what are you",
        "tell me about yourself"
    ]

    # Build boundary-anchored alternations so phrases only match whole words.
    any_phrase = "|".join(re.escape(p) for p in conversational_phrases)
    greeting_phrase = "|".join(re.escape(p) for p in conversational_phrases[:10])

    starts_conversational = re.match(rf"(?:{any_phrase})\b", message_lower) is not None
    contains_greeting = re.search(rf"\b(?:{greeting_phrase})\b", message_lower) is not None
    # Short questions (<= 8 words ending in "?") are treated as conversation.
    short_question = message_lower.endswith("?") and len(message_lower.split()) <= 8

    if starts_conversational or contains_greeting or short_question:
        return "conversation"

    # Programming-related keywords. Substring containment is acceptable here:
    # these terms are long/specific enough not to hide inside unrelated words.
    code_keywords = ["code", "function", "script", "program", "algorithm", "implement",
                     "write", "create", "python", "javascript", "java", "c++"]

    if any(keyword in message_lower for keyword in code_keywords):
        return "code"

    # If in doubt, treat as conversation.
    return "conversation"


# API call function with retry and improved error handling
def generate_response_groq(messages: List[Dict[str, str]]) -> str:
    """Send chat messages to the Groq API, retrying transient failures.

    Makes up to 3 attempts with exponential backoff on rate limiting (429),
    service unavailability (503), timeouts, and connection errors.

    Args:
        messages: Conversation in OpenAI chat format
            ([{"role": ..., "content": ...}, ...]).

    Returns:
        The assistant's generated text, or "No response generated" when the
        API returns an empty choices list.

    Raises:
        HTTPException: 500 (missing key, unexpected error, or retries
            exhausted), 401 (invalid key), 504 (timeout), 503 (connection
            failure), or the upstream status code for other API errors.
    """
    if not GROQ_API_KEY:
        raise HTTPException(status_code=500, detail="GROQ_API_KEY is missing.")

    url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": GROQ_MODEL,
        "messages": messages,
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }

    for attempt in range(3):  # Retry logic
        try:
            print(f"πŸ”„ Attempt {attempt + 1} - Sending request to Groq API")
            response = requests.post(url, headers=headers, json=payload, timeout=60)
            print(f"πŸ“Š Status Code: {response.status_code}")

            if response.status_code == 200:
                result = response.json()
                if "choices" in result and len(result["choices"]) > 0:
                    return result["choices"][0]["message"]["content"]
                return "No response generated"

            elif response.status_code == 401:  # Unauthorized (Invalid API key)
                print("❌ Authentication error: Invalid API Key")
                raise HTTPException(status_code=401, detail="Invalid API Key. Check your GROQ_API_KEY.")

            elif response.status_code == 429:  # Rate limit error
                print("⚠️ Rate limited, retrying...")
                time.sleep(2 ** attempt)  # Exponential backoff
                continue

            elif response.status_code == 503:  # Service unavailable
                print("⚠️ Service unavailable, retrying...")
                time.sleep(2 ** attempt)
                continue

            else:
                error_detail = "Unknown error"
                try:
                    error_data = response.json()
                    error_detail = error_data.get("error", {}).get("message", str(error_data))
                except ValueError:
                    # Body was not JSON (requests' JSONDecodeError subclasses
                    # ValueError); fall back to the raw text. Was a bare
                    # `except:` which even caught KeyboardInterrupt.
                    error_detail = response.text

                print(f"❌ API Error: {error_detail}")
                if attempt == 2:  # Last attempt
                    raise HTTPException(status_code=response.status_code,
                                        detail=f"Groq API Error: {error_detail}")

        except HTTPException:
            # Bug fix: deliberate HTTP errors (e.g. the 401 above, or the
            # last-attempt API error) were previously caught by the generic
            # Exception handler below and retried or re-wrapped as 500.
            # Propagate them unchanged so callers see the real status code.
            raise

        except requests.exceptions.Timeout:
            print("⚠️ Request timed out, retrying...")
            if attempt == 2:  # Last attempt
                raise HTTPException(status_code=504, detail="Request timed out")

        except requests.exceptions.ConnectionError:
            print("⚠️ Connection error, retrying...")
            if attempt == 2:  # Last attempt
                raise HTTPException(status_code=503, detail="Could not connect to Groq API")

        except Exception as e:
            print(f"❌ Unexpected error: {str(e)}")
            if attempt == 2:  # Last attempt
                raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

        # Wait before retry (except on last attempt)
        if attempt < 2:
            time.sleep(2 ** attempt)

    raise HTTPException(status_code=500, detail="Failed to get response after multiple attempts")


# Matches one fenced markdown code block with any (or no) language tag,
# e.g. ```python, ```c++, ```html, or a bare ```. Compiled once at import.
_CODE_BLOCK_RE = re.compile(r"```[^\n]*\n(.*?)```", re.DOTALL)


# Helper function to process and format the model's response
def process_response(raw_response: str, response_type: str) -> Dict[str, Any]:
    """Shape the model's raw text according to the requested response type.

    A single pattern is used for both extracting and stripping code blocks.
    Previously, extraction recognized only five hard-coded language tags
    (so e.g. ```html blocks were never extracted), while the stripping
    regex used ``\\w+`` which cannot match "c++" — leaving c++ blocks
    inside "explanation" output.

    Args:
        raw_response: Full text returned by the model, possibly containing
            fenced markdown code blocks.
        response_type: "conversation", "code", "explanation", or anything
            else (treated as "both").

    Returns:
        A dict whose keys depend on response_type:
        - "conversation": {"response"}
        - "code":         {"generated_code"}
        - "explanation":  {"explanation"}
        - otherwise:      {"response", "generated_code", "explanation"}
    """
    # For conversational responses, don't try to extract code
    if response_type == "conversation":
        return {"response": raw_response}

    if response_type == "code":
        # Extract the first fenced code block, regardless of language tag.
        code_match = _CODE_BLOCK_RE.search(raw_response)
        if code_match:
            return {"generated_code": code_match.group(1).strip()}
        # If no code block found, return the whole response as code
        return {"generated_code": raw_response}

    if response_type == "explanation":
        # Remove all fenced code blocks, keeping only the prose.
        explanation = _CODE_BLOCK_RE.sub("", raw_response).strip()
        return {"explanation": explanation}

    # "both" (default): raw text plus split-out code and prose.
    code = None
    explanation = raw_response

    code_match = _CODE_BLOCK_RE.search(raw_response)
    if code_match:
        code = code_match.group(1).strip()
        # Remove code blocks from the explanation text.
        explanation = _CODE_BLOCK_RE.sub("", raw_response).strip()

    return {
        "response": raw_response,
        "generated_code": code,
        "explanation": explanation
    }


# API route for generating responses
@app.post("/generate/")
async def generate_response(request: PromptRequest):
    """Generate a model reply for the prompt, maintaining session history.

    Classifies the prompt as conversational or code-related, builds an
    OpenAI-format message list (system prompt + recent history + prompt),
    calls the Groq model, records the exchange, and shapes the reply
    according to request.response_type.

    Returns:
        Dict with "session_id", "message_type" and — depending on the
        classification and response_type — "response", "generated_code"
        and/or "explanation".

    Raises:
        HTTPException: propagated unchanged from the Groq call, or 500 on
            any other unexpected failure.
    """
    try:
        # Create a fresh session when the client did not supply one.
        session_id = request.session_id or str(uuid.uuid4())
        history = conversation_history.setdefault(session_id, [])

        # Classify the message type first
        message_type = classify_message(request.prompt)

        # Pick the system prompt: conversational messages always get the
        # assistant persona; code-related messages follow response_type.
        if message_type == "conversation":
            system_prompt = "You are a helpful and friendly AI assistant. Engage in natural conversation and answer questions clearly."
        elif request.response_type == "code":
            system_prompt = "You are an expert programmer. Provide clean, efficient code solutions. Always wrap code in markdown code blocks with the appropriate language tag."
        elif request.response_type == "explanation":
            system_prompt = "You are a programming tutor. Explain programming concepts clearly without providing code. Focus on the approach and logic."
        else:  # both
            system_prompt = "You are an expert programmer and teacher. Provide clear explanations followed by well-commented code examples. Always wrap code in markdown code blocks."

        # System prompt + last 6 history messages (context cap) + new prompt.
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(history[-6:])
        messages.append({"role": "user", "content": request.prompt})

        # Get response from Groq model
        print(f"πŸ“€ Sending {len(messages)} messages to Groq...")
        generated_response = generate_response_groq(messages)
        print(f"βœ… Received response of length: {len(generated_response)}")

        # Store conversation history in OpenAI message format
        history.append({"role": "user", "content": request.prompt})
        history.append({"role": "assistant", "content": generated_response})

        # Limit history size to prevent memory issues (keep last 20 messages = 10 exchanges)
        if len(history) > 20:
            conversation_history[session_id] = history[-20:]

        if message_type == "conversation":
            # Conversational replies skip code/explanation extraction.
            response_data = {
                "response": generated_response,
                "message_type": "conversation"
            }
        else:
            # Handle response type and build response data for code-related messages
            response_data = process_response(generated_response, request.response_type)
            response_data["message_type"] = "code"

        response_data["session_id"] = session_id
        return response_data

    except HTTPException:
        # Re-raise as-is to preserve upstream status codes (was `as e` with
        # the variable never used).
        raise
    except Exception as e:
        print(f"❌ Unexpected error in generate_response: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")


# API route for clearing conversation history
@app.post("/clear_history/")
async def clear_history(request: HistoryRequest):
    """Empty the stored message list for a known session.

    The session entry itself is kept (reset to an empty list); an unknown
    session_id yields a not_found status.
    """
    if request.session_id not in conversation_history:
        return {"status": "not_found", "message": "Session ID not found"}

    conversation_history[request.session_id] = []
    return {"status": "success", "message": "Conversation history cleared"}


# API route for getting conversation history
@app.post("/get_history/")
async def get_history(request: HistoryRequest):
    """Return the stored message list for a session, or a not_found status."""
    history = conversation_history.get(request.session_id)
    if history is None:
        return {"status": "not_found", "message": "Session ID not found"}

    return {
        "status": "success",
        "history": history
    }


# Health check endpoint
@app.get("/")
@app.get("/health")
async def health_check():
    """Report liveness plus the configured model and API version."""
    status_payload = {
        "status": "ok",
        "service": "Groq Code Generation API",
    }
    status_payload["model"] = GROQ_MODEL
    status_payload["version"] = "1.0.0"
    return status_payload


# Request logging middleware for debugging
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Print one line per request: method, path, status, and elapsed seconds."""
    started = time.time()
    response = await call_next(request)
    elapsed = time.time() - started
    print(f"πŸ“ {request.method} {request.url.path} β†’ Status: {response.status_code} ({elapsed:.2f}s)")
    return response


if __name__ == "__main__":
    # Standalone entry point: binds on all interfaces so the server is
    # reachable from outside a container.
    import uvicorn
    port = int(os.getenv("PORT", "7860"))  # Hugging Face Spaces uses port 7860
    uvicorn.run(app, host="0.0.0.0", port=port)