nixaut-codelabs commited on
Commit
13e06f7
·
verified ·
1 Parent(s): e574ca7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends, status
2
+ from fastapi.security import HTTPBearer, HTTPBearerToken
3
+ from pydantic import BaseModel
4
+ from transformers import pipeline
5
+ import gradio as gr
6
+ import os
7
+ from dotenv import load_dotenv
8
+ import uvicorn
9
+ import threading
10
+
11
+ load_dotenv()
12
+
13
+ app = FastAPI(title="AI Prompt Enhancer", version="1.0.0")
14
+ security = HTTPBearer()
15
+
16
+ API_KEY = os.getenv("API_KEY")
17
+ if not API_KEY:
18
+ raise ValueError("API_KEY not found in environment variables")
19
+
20
+ pipe = pipeline("text-generation", model="unsloth/gemma-3-270m-it-GGUF")
21
+
22
+ def load_system_prompt():
23
+ try:
24
+ with open("prompt.txt", "r", encoding="utf-8") as f:
25
+ return f.read().strip()
26
+ except FileNotFoundError:
27
+ return "You are an AI assistant that enhances prompts to make them more effective and detailed."
28
+
29
+ SYSTEM_PROMPT = load_system_prompt()
30
+
31
+ class EnhanceRequest(BaseModel):
32
+ prompt: str
33
+
34
+ class EnhanceResponse(BaseModel):
35
+ enhanced_prompt: str
36
+
37
+ def verify_api_key(token: HTTPBearerToken = Depends(security)):
38
+ if token.credentials != API_KEY:
39
+ raise HTTPException(
40
+ status_code=status.HTTP_401_UNAUTHORIZED,
41
+ detail="Invalid API key"
42
+ )
43
+ return token.credentials
44
+
45
+ @app.post("/enhance", response_model=EnhanceResponse)
46
+ async def enhance_prompt(request: EnhanceRequest, api_key: str = Depends(verify_api_key)):
47
+ messages = [
48
+ {"role": "system", "content": SYSTEM_PROMPT},
49
+ {"role": "user", "content": request.prompt}
50
+ ]
51
+
52
+ try:
53
+ result = pipe(messages, max_length=512, temperature=0.7)
54
+ enhanced_prompt = result[0]["generated_text"]
55
+
56
+ if isinstance(enhanced_prompt, list):
57
+ user_message = next((msg["content"] for msg in enhanced_prompt if msg["role"] == "assistant"), enhanced_prompt[-1]["content"])
58
+ else:
59
+ user_message = enhanced_prompt.split("assistant")[-1].strip() if "assistant" in enhanced_prompt else enhanced_prompt
60
+
61
+ return EnhanceResponse(enhanced_prompt=user_message)
62
+ except Exception as e:
63
+ raise HTTPException(status_code=500, detail=f"Enhancement failed: {str(e)}")
64
+
65
+ def enhance_gradio(prompt_text):
66
+ if not prompt_text.strip():
67
+ return "Please enter a prompt to enhance."
68
+
69
+ messages = [
70
+ {"role": "system", "content": SYSTEM_PROMPT},
71
+ {"role": "user", "content": prompt_text}
72
+ ]
73
+
74
+ try:
75
+ result = pipe(messages, max_length=512, temperature=0.7)
76
+ enhanced_prompt = result[0]["generated_text"]
77
+
78
+ if isinstance(enhanced_prompt, list):
79
+ user_message = next((msg["content"] for msg in enhanced_prompt if msg["role"] == "assistant"), enhanced_prompt[-1]["content"])
80
+ else:
81
+ user_message = enhanced_prompt.split("assistant")[-1].strip() if "assistant" in enhanced_prompt else enhanced_prompt
82
+
83
+ return user_message
84
+ except Exception as e:
85
+ return f"Enhancement failed: {str(e)}"
86
+
87
+ iface = gr.Interface(
88
+ fn=enhance_gradio,
89
+ inputs=gr.Textbox(
90
+ lines=5,
91
+ placeholder="Enter your prompt here to enhance it...",
92
+ label="Original Prompt"
93
+ ),
94
+ outputs=gr.Textbox(
95
+ lines=8,
96
+ label="Enhanced Prompt"
97
+ ),
98
+ title="AI Prompt Enhancer",
99
+ description="Transform your basic prompts into detailed, effective instructions.",
100
+ examples=[
101
+ ["Write a story about a robot"],
102
+ ["Explain machine learning"],
103
+ ["Create a marketing plan"]
104
+ ]
105
+ )
106
+
107
+ def run_gradio():
108
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
109
+
110
+ if __name__ == "__main__":
111
+ gradio_thread = threading.Thread(target=run_gradio, daemon=True)
112
+ gradio_thread.start()
113
+
114
+ uvicorn.run(app, host="0.0.0.0", port=8000)