Spaces:
Build error
Build error
File size: 4,252 Bytes
13e06f7 762fe84 13e06f7 821b5d0 13e06f7 821b5d0 13e06f7 821b5d0 13e06f7 762fe84 13e06f7 762fe84 13e06f7 821b5d0 13e06f7 821b5d0 13e06f7 821b5d0 13e06f7 821b5d0 13e06f7 40a687a 13e06f7 40a687a 821b5d0 13e06f7 821b5d0 5240a51 821b5d0 5291883 821b5d0 5291883 821b5d0 9d68c64 7c49eb7 821b5d0 7c49eb7 5240a51 821b5d0 13e06f7 40a687a 78e2a49 13e06f7 78e2a49 6cfb018 13e06f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import secrets
import threading

import gradio as gr
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import HTTPBearer
from huggingface_hub import snapshot_download
from llama_cpp import Llama
from pydantic import BaseModel
load_dotenv()
app = FastAPI(title="AI Prompt Enhancer", version="1.0.0")
security = HTTPBearer()
API_KEY = os.getenv("API_KEY")
if not API_KEY:
raise ValueError("API_KEY not found in environment variables")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
snapshot_download(
repo_id="unsloth/gemma-3-270m-it-GGUF",
local_dir="gemma-3-270m-it-GGUF",
allow_patterns=["*UD-Q8_K_XL*"]
)
llm = Llama(
model_path="gemma-3-270m-it-GGUF/gemma-3-270m-it-UD-Q8_K_XL.gguf",
n_ctx=4096,
n_threads=2,
n_gpu_layers=0
)
def load_system_prompt():
try:
with open("prompt.txt", "r", encoding="utf-8") as f:
return f.read().strip()
except FileNotFoundError:
return "You are an AI assistant that enhances prompts to make them more effective and detailed."
SYSTEM_PROMPT = load_system_prompt()
class EnhanceRequest(BaseModel):
prompt: str
class EnhanceResponse(BaseModel):
enhanced_prompt: str
def verify_api_key(credentials = Depends(security)):
if credentials.credentials != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key"
)
return credentials.credentials
@app.post("/enhance", response_model=EnhanceResponse)
async def enhance_prompt(request: EnhanceRequest, api_key: str = Depends(verify_api_key)):
full_prompt = f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{request.prompt}<end_of_turn>\n<start_of_turn>model\n"
try:
result = llm(
full_prompt,
max_tokens=512,
temperature=0.7,
top_k=40,
top_p=0.95,
repeat_penalty=1.1,
stop=["<end_of_turn>"]
)
enhanced_prompt = result["choices"][0]["text"].strip()
if not enhanced_prompt:
raise HTTPException(status_code=500, detail="Enhancement failed: Empty response")
return EnhanceResponse(enhanced_prompt=enhanced_prompt)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Enhancement failed: {str(e)}")
def enhance_for_gradio(prompt_text, api_key):
if not prompt_text.strip():
return "Please enter a prompt to enhance."
if not api_key.strip():
return "Please enter your API key."
if api_key != API_KEY:
return "Invalid API key."
full_prompt = f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{prompt_text}<end_of_turn>\n<start_of_turn>model\n"
try:
result = llm(
full_prompt,
max_tokens=512,
temperature=1,
top_k=64,
top_p=0.95,
repeat_penalty=1.1,
min_p: 0.01,
repeat_penalty: 1.0,
stop=["<end_of_turn>"]
)
enhanced_prompt = result["choices"][0]["text"].strip()
if not enhanced_prompt:
return "Model generated empty response."
return enhanced_prompt
except Exception as e:
return f"Enhancement failed: {str(e)}"
iface = gr.Interface(
fn=enhance_for_gradio,
inputs=[
gr.Textbox(
lines=5,
placeholder="Enter your prompt here to enhance it...",
label="Original Prompt"
),
gr.Textbox(
placeholder="Enter your API key",
label="API Key",
type="password"
)
],
outputs=gr.Textbox(
lines=8,
label="Enhanced Prompt"
),
title="AI Prompt Enhancer",
description="Transform your basic prompts into detailed, effective instructions. API key required.",
cache_examples=False
)
def run_gradio():
iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
if __name__ == "__main__":
gradio_thread = threading.Thread(target=run_gradio, daemon=True)
gradio_thread.start()
uvicorn.run(app, host="0.0.0.0", port=8000) |