Spaces:
Sleeping
Sleeping
import json
import os
from datetime import datetime
from typing import List

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from google import generativeai as genai
from pydantic import BaseModel, Field

from enhanced_prompt_builder import EnhancedPromptBuilder
from feedback_analyzer import FeedbackAnalyzer
load_dotenv()

# Read the Gemini API key from the environment. On Hugging Face Spaces it comes
# from the Space's secrets; locally, load_dotenv() above can supply it from .env.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("GEMINI_API_KEY not found in environment.")

# Hand the validated key to the SDK explicitly, so the model below is guaranteed
# to authenticate with it. Some SDK versions only look at GOOGLE_API_KEY on
# their own.
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-2.5-flash")
def call_gemini(prompt: str, model_name: str = "gemini-2.5-flash", timeout: float = 60.0) -> str:
    """Generate text for *prompt* via the Gemini REST ``generateContent`` endpoint.

    This is a plain-HTTP alternative to the SDK-based ``model`` object. It uses
    only the standard library; the previous version called ``requests`` without
    importing it.

    Args:
        prompt: Text sent as the single part of the request's content.
        model_name: Gemini model id used in the endpoint path. It defaults to
            the same model as the SDK ``model``; the previously hard-coded
            ``gemini-pro`` has been retired.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The text of the first part of the first candidate in the response.

    Raises:
        HTTPException: 500 in three cases: the API answers with an HTTP error,
            it cannot be reached, or the response body lacks the expected
            structure.
    """
    import urllib.error
    import urllib.parse
    import urllib.request

    url = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        f"{urllib.parse.quote(model_name, safe='')}:generateContent"
    )
    body = json.dumps({"contents": [{"parts": [{"text": prompt}]}]}).encode("utf-8")
    # Send the key as a header rather than a query parameter, so it does not end
    # up in URLs that proxies or error messages may log.
    request = urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json", "x-goog-api-key": GEMINI_API_KEY},
        method="POST",
    )

    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            raw = response.read()
    except urllib.error.HTTPError as err:
        # HTTPError must be caught before its base class URLError/OSError.
        raise HTTPException(
            status_code=500, detail=f"Gemini API request failed with HTTP {err.code}."
        ) from err
    except OSError as err:
        # URLError, connection resets and socket timeouts are all OSError subclasses.
        raise HTTPException(status_code=500, detail="Could not reach the Gemini API.") from err

    try:
        return json.loads(raw)["candidates"][0]["content"]["parts"][0]["text"]
    except (ValueError, KeyError, IndexError, TypeError) as err:
        raise HTTPException(status_code=500, detail="Error in Gemini response format.") from err
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and any header.
# NOTE(review): FastAPI's CORS docs say allow_credentials=True should not be
# combined with "*" wildcards. Confirm whether credentialed requests are needed;
# if they are, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# One shared prompt builder and one shared feedback analyzer, created once at
# import time and used by the functions below.
enhanced_builder = EnhancedPromptBuilder()
feedback_analyzer = FeedbackAnalyzer()
class AdRequest(BaseModel):
    """Input for an ad rewrite.

    All three fields are passed unchanged to the prompt builder.
    """

    ad_text: str  # the ad copy to rewrite
    tone: str  # requested tone for the rewrite
    platforms: List[str]  # target platform names
class Feedback(BaseModel):
    """A user's rating of one rewritten ad; these fields are stored as one feedback entry."""

    ad_text: str
    tone: str
    platforms: List[str]
    rewritten_output: str
    # Rating on a 1-5 scale. Enforced by validation so that out-of-range values
    # are rejected instead of being written to the feedback store.
    rating: int = Field(..., ge=1, le=5)
def run_enhanced_agent(request: AdRequest):
    """Rewrite an ad with the adaptive prompt builder and the Gemini model.

    Builds a prompt from the request's text, tone and platforms, and sends it
    to the module-level ``model``. Returns the generated text plus the first
    three improvement suggestions reported by the prompt builder.

    Raises:
        HTTPException: 500 carrying the original error message if any step fails.
    """
    try:
        adaptive_prompt = enhanced_builder.build_adaptive_prompt(
            request.ad_text, request.tone, request.platforms
        )
        generation = model.generate_content(adaptive_prompt)
        all_suggestions = enhanced_builder.get_improvement_suggestions()

        rewritten = generation.text
        metadata = {
            "used_enhanced_features": True,
            "improvement_suggestions": all_suggestions[:3],
        }
        return {"rewritten_ads": rewritten, "metadata": metadata}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
def submit_feedback(feedback: Feedback):
    """Append one rated rewrite, with a timestamp, to the JSON feedback store.

    The store is ``feedback_store.json`` in the current working directory. It
    holds a JSON list of entries. A missing or empty store is treated as an
    empty list, so the first submission creates the file.

    The whole list is written to a temporary file, which is then moved over the
    store. A failed write therefore never leaves a truncated store behind, and
    the new JSON never leaves stale bytes from an older version of the file.

    NOTE(review): concurrent submissions can still overwrite each other's
    entries (read-modify-write without a lock). Confirm whether that matters
    for this deployment.

    Args:
        feedback: The rated rewrite to record.

    Returns:
        ``{"message": "Feedback submitted successfully"}`` on success.

    Raises:
        HTTPException: 500 when the store cannot be read, parsed or written.
    """
    store_path = "feedback_store.json"
    entry = {
        "timestamp": datetime.now().isoformat(),
        "ad_text": feedback.ad_text,
        "tone": feedback.tone,
        "platforms": feedback.platforms,
        "rewritten_output": feedback.rewritten_output,
        "rating": feedback.rating,
    }
    try:
        try:
            with open(store_path, "r", encoding="utf-8") as f:
                raw = f.read()
        except FileNotFoundError:
            raw = ""  # first submission: the store does not exist yet

        data = json.loads(raw) if raw.strip() else []
        if not isinstance(data, list):
            raise ValueError(f"{store_path} must contain a JSON list")
        data.append(entry)

        # Write the full list to a sibling file, then swap it in with os.replace.
        tmp_path = store_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, store_path)
        return {"message": "Feedback submitted successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error storing feedback: {str(e)}") from e
def get_insights():
    """Summarize stored feedback using the module-level ``feedback_analyzer``.

    The result combines several parts:
    - an overall summary: feedback count, average rating rounded to 2
      decimals, and up to 5 recommendations;
    - per-tone and per-platform statistics;
    - high- and low-performing patterns;
    - the analyzer's adaptive weights and time-based trends.

    Keys missing from the analysis fall back to 0 or an empty container.

    Raises:
        HTTPException: 500 carrying the original error message if analysis fails.
    """
    try:
        patterns = feedback_analyzer.analyze_patterns()
        recent = feedback_analyzer.get_time_based_trends()
        adaptive = feedback_analyzer.get_adaptive_weights()

        summary = {
            "total_feedback": patterns.get("total_feedback", 0),
            "average_rating": round(patterns.get("average_rating", 0), 2),
            "recommendations": patterns.get("recommendations", [])[:5],
        }
        by_tone = patterns.get("tone_stats", {})
        by_platform = patterns.get("platform_stats", {})
        winners = patterns.get("high_performing_patterns", [])
        laggards = patterns.get("low_performing_patterns", [])

        return {
            "analysis_summary": summary,
            "performance_by_tone": by_tone,
            "performance_by_platform": by_platform,
            "winning_combinations": winners,
            "needs_improvement": laggards,
            "adaptive_weights": adaptive,
            "recent_trends": recent,
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
def get_graph_insights(tone: str, platform: str):
    """Describe how *tone* and *platform* relate in the knowledge graph.

    A new ``EnhancedKnowledgeGraph`` is constructed on every call. The result
    has two parts: the graph's recommendations and relationship explanation for
    the pair, and the node keys returned by a depth-2 ``traverse_bfs`` from
    each side.

    Raises:
        HTTPException: 500 carrying the original error message if any step
            fails, including the import of ``enhanced_knowledge_graph``.
    """
    try:
        from enhanced_knowledge_graph import EnhancedKnowledgeGraph

        graph = EnhancedKnowledgeGraph()
        recs = graph.get_recommendations(tone, platform)
        explanation = graph.explain_relationship(tone, platform)
        # Presumably the nodes reachable within two hops of each node; confirm
        # against traverse_bfs.
        near_tone = graph.traverse_bfs(tone, max_depth=2)
        near_platform = graph.traverse_bfs(platform, max_depth=2)

        analysis = {
            "tone": tone,
            "platform": platform,
            "compatibility_score": recs["compatibility_score"],
            "relationship_explanation": explanation,
            "suggestions": recs["suggested_elements"],
            "warnings": recs["warnings"],
            "recommended_creative_types": recs["creative_types"],
        }
        connections = {
            "tone_connections": list(near_tone.keys()),
            "platform_connections": list(near_platform.keys()),
        }
        return {"tone_platform_analysis": analysis, "graph_connections": connections}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))