# Author: nixaut-codelabs
# Last commit: 5291883 ("Update app.py")
import os
import secrets
import threading

import gradio as gr
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import HTTPBearer
from huggingface_hub import snapshot_download
from llama_cpp import Llama
from pydantic import BaseModel
# Populate os.environ from a local .env file (python-dotenv) before reading config.
load_dotenv()

app = FastAPI(version="1.0.0", title="AI Prompt Enhancer")
# Bearer-token scheme; verify_api_key reads the token it extracts.
security = HTTPBearer()

# Refuse to start at all when no API key is configured (unset or empty).
API_KEY = os.environ.get("API_KEY")
if API_KEY is None or API_KEY == "":
    raise ValueError("API_KEY not found in environment variables")
# NOTE(review): huggingface_hub is imported at the top of this file, before
# this assignment runs. It appears to read HF_HUB_ENABLE_HF_TRANSFER when it is
# first imported, so this line most likely has no effect on the download below.
# To really enable fast transfers, set it before the imports; that also
# requires the `hf_transfer` package. Confirm before changing.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Single source of truth for which GGUF file is downloaded and loaded.
# Previously the quantization tag / filename was repeated in two literals.
MODEL_REPO_ID = "unsloth/gemma-3-270m-it-GGUF"
MODEL_DIR = "gemma-3-270m-it-GGUF"
MODEL_QUANT = "UD-Q8_K_XL"
MODEL_FILE = f"gemma-3-270m-it-{MODEL_QUANT}.gguf"

# Download only the files whose names contain the chosen quantization tag
# into MODEL_DIR.
snapshot_download(
    repo_id=MODEL_REPO_ID,
    local_dir=MODEL_DIR,
    allow_patterns=[f"*{MODEL_QUANT}*"],
)

# Load the model for CPU inference: no layers offloaded to GPU,
# a 4096-token context window, and 2 worker threads.
llm = Llama(
    model_path=os.path.join(MODEL_DIR, MODEL_FILE),
    n_ctx=4096,
    n_threads=2,
    n_gpu_layers=0,
)
_DEFAULT_SYSTEM_PROMPT = (
    "You are an AI assistant that enhances prompts to make them more "
    "effective and detailed."
)


def load_system_prompt(path="prompt.txt"):
    """Return the system prompt read from *path*, or a built-in default.

    The file content is stripped of leading/trailing whitespace. The default
    prompt is returned when the file does not exist or contains only
    whitespace, so the model is never run with an empty instruction block.

    Args:
        path: File to read (UTF-8). Relative paths resolve against the current
            working directory. Defaults to ``"prompt.txt"``.

    Returns:
        The non-empty system prompt text.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read().strip()
    except FileNotFoundError:
        return _DEFAULT_SYSTEM_PROMPT
    # An empty prompt.txt would otherwise silently drop all instructions.
    return text or _DEFAULT_SYSTEM_PROMPT


SYSTEM_PROMPT = load_system_prompt()
# JSON request body for POST /enhance.
class EnhanceRequest(BaseModel):
    prompt: str  # the user's original prompt text to be enhanced


# JSON response body for POST /enhance.
class EnhanceResponse(BaseModel):
    enhanced_prompt: str  # the model's rewritten prompt (whitespace-stripped)
def verify_api_key(credentials=Depends(security)):
    """FastAPI dependency that enforces the configured bearer-token API key.

    ``security`` (an ``HTTPBearer`` instance) supplies the credentials parsed
    from the ``Authorization: Bearer <token>`` header. The token is checked
    against ``API_KEY``.

    Returns:
        The presented token, when it matches ``API_KEY``.

    Raises:
        HTTPException: 401 Unauthorized when the token does not match.
    """
    presented = credentials.credentials
    # compare_digest takes time independent of where the inputs differ, so
    # response timing cannot be used to guess the key byte by byte. It raises
    # TypeError on non-ASCII str, so compare the UTF-8 bytes instead.
    if not secrets.compare_digest(presented.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            # A 401 should tell the client which auth scheme is expected.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return presented
@app.post("/enhance", response_model=EnhanceResponse)
async def enhance_prompt(request: EnhanceRequest, api_key: str = Depends(verify_api_key)):
    """Rewrite ``request.prompt`` into a more detailed prompt using the local LLM.

    Args:
        request: JSON body carrying the user's ``prompt``.
        api_key: Injected by ``verify_api_key``. It is unused here; declaring it
            makes a valid API key a precondition of this route.

    Returns:
        EnhanceResponse whose ``enhanced_prompt`` is the stripped completion.

    Raises:
        HTTPException: 400 when the prompt is blank; 500 when generation fails
            or the model returns an empty completion.
    """
    # Mirror the Gradio path, which also refuses blank prompts.
    if not request.prompt.strip():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Prompt must not be empty",
        )
    # Gemma turn format: the system instructions are placed inside the user
    # turn, and the model turn is left open for the completion.
    full_prompt = (
        f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{request.prompt}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    # NOTE(review): llm(...) is a synchronous call, so inside this async handler
    # it blocks the event loop for the entire generation. The Gradio thread also
    # calls the same llm object; consider serializing access with a lock.
    try:
        result = llm(
            full_prompt,
            max_tokens=512,
            temperature=0.7,
            top_k=40,
            top_p=0.95,
            repeat_penalty=1.1,
            stop=["<end_of_turn>"],
        )
        enhanced_prompt = result["choices"][0]["text"].strip()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Enhancement failed: {e}") from e
    # Checked outside the try: previously this HTTPException was caught by the
    # broad handler above and re-wrapped with a doubled error message.
    if not enhanced_prompt:
        raise HTTPException(status_code=500, detail="Enhancement failed: Empty response")
    return EnhanceResponse(enhanced_prompt=enhanced_prompt)
def enhance_for_gradio(prompt_text, api_key):
    """Gradio callback: validate inputs, check the API key, and enhance the prompt.

    Unlike the HTTP endpoint, every outcome, including errors, is returned as
    a plain string so it is displayed in the output textbox.

    Args:
        prompt_text: Contents of the "Original Prompt" box (``None`` is
            treated as empty).
        api_key: Contents of the "API Key" box (``None`` is treated as empty).

    Returns:
        The enhanced prompt, or a human-readable error message.
    """
    prompt_text = prompt_text or ""
    api_key = api_key or ""
    if not prompt_text.strip():
        return "Please enter a prompt to enhance."
    if not api_key.strip():
        return "Please enter your API key."
    # Constant-time comparison so response timing does not leak the key;
    # bytes avoid compare_digest's TypeError on non-ASCII str.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        return "Invalid API key."
    full_prompt = (
        f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{prompt_text}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    try:
        # The original call passed `min_p: 0.01, repeat_penalty: 1.0` (a
        # SyntaxError) and repeated repeat_penalty. The later values are kept.
        # These sampling settings intentionally differ from the /enhance
        # endpoint; they presumably follow the recommended Gemma 3 settings.
        # Confirm this is intended.
        result = llm(
            full_prompt,
            max_tokens=512,
            temperature=1.0,
            top_k=64,
            top_p=0.95,
            min_p=0.01,
            repeat_penalty=1.0,
            stop=["<end_of_turn>"],
        )
        enhanced_prompt = result["choices"][0]["text"].strip()
    except Exception as e:
        return f"Enhancement failed: {e}"
    if not enhanced_prompt:
        return "Model generated empty response."
    return enhanced_prompt
# Gradio UI. The two input boxes map, in order, onto enhance_for_gradio's
# (prompt_text, api_key) parameters; its string result fills the output box.
iface = gr.Interface(
    fn=enhance_for_gradio,
    inputs=[
        gr.Textbox(
            label="Original Prompt",
            placeholder="Enter your prompt here to enhance it...",
            lines=5,
        ),
        gr.Textbox(
            label="API Key",
            placeholder="Enter your API key",
            type="password",
        ),
    ],
    outputs=gr.Textbox(label="Enhanced Prompt", lines=8),
    title="AI Prompt Enhancer",
    description=(
        "Transform your basic prompts into detailed, effective instructions. "
        "API key required."
    ),
    cache_examples=False,
)
def run_gradio():
    """Launch the Gradio UI on all interfaces, port 7860, without a share link."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    iface.launch(**launch_options)
if __name__ == "__main__":
    bind_host = "0.0.0.0"
    # The Gradio UI runs on a daemon thread. The process lifetime is tied to
    # the FastAPI server that uvicorn.run serves from the main thread.
    gradio_thread = threading.Thread(daemon=True, target=run_gradio)
    gradio_thread.start()
    uvicorn.run(app, port=8000, host=bind_host)