# Hugging Face Space status at paste time: "Build error" (twice).
# Root cause: invalid keyword syntax inside enhance_for_gradio
# (`min_p: 0.01` and a duplicate `repeat_penalty: 1.0` in the llm() call).
# Standard library.
import os
import threading

# Third-party.
import gradio as gr
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import HTTPBearer
from huggingface_hub import snapshot_download
from llama_cpp import Llama
from pydantic import BaseModel

# Pull variables from a local .env file into the process environment.
load_dotenv()
# FastAPI application plus the bearer-token security scheme used by the API.
app = FastAPI(title="AI Prompt Enhancer", version="1.0.0")
security = HTTPBearer()

# Shared secret every client must present; fail fast at startup if unset.
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    raise ValueError("API_KEY not found in environment variables")
# Turn on the accelerated hf_transfer download backend, then fetch only the
# UD-Q8_K_XL quantization of the GGUF model into a local directory.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
snapshot_download(
    repo_id="unsloth/gemma-3-270m-it-GGUF",
    local_dir="gemma-3-270m-it-GGUF",
    allow_patterns=["*UD-Q8_K_XL*"],
)
# CPU-only llama.cpp instance: 4k context window, two worker threads,
# no layers offloaded to GPU (n_gpu_layers=0).
llm = Llama(
    model_path="gemma-3-270m-it-GGUF/gemma-3-270m-it-UD-Q8_K_XL.gguf",
    n_ctx=4096,
    n_threads=2,
    n_gpu_layers=0,
)
def load_system_prompt():
    """Return the system prompt read from prompt.txt, or a built-in default.

    The default is used only when prompt.txt is absent; an existing file is
    returned stripped of surrounding whitespace (even if that leaves it empty).
    """
    fallback = "You are an AI assistant that enhances prompts to make them more effective and detailed."
    try:
        with open("prompt.txt", "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError:
        return fallback
    return text.strip()


# Loaded once at import time and reused verbatim by every request.
SYSTEM_PROMPT = load_system_prompt()
class EnhanceRequest(BaseModel):
    """API request body: the raw prompt the caller wants enhanced."""

    prompt: str


class EnhanceResponse(BaseModel):
    """API response body: the model-enhanced version of the prompt."""

    enhanced_prompt: str
def verify_api_key(credentials = Depends(security)):
    """FastAPI dependency validating the bearer token against API_KEY.

    Args:
        credentials: HTTPAuthorizationCredentials injected by the
            HTTPBearer security scheme.

    Returns:
        str: the validated token.

    Raises:
        HTTPException: 401 when the presented token does not match.
    """
    import hmac  # stdlib; local import keeps the module's import block untouched

    # Fix: a plain `!=` on a secret short-circuits at the first differing
    # byte, leaking timing information; compare_digest is constant-time.
    if not hmac.compare_digest(credentials.credentials, API_KEY):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
            # RFC 6750: advertise the expected auth scheme on 401.
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
@app.post("/enhance", response_model=EnhanceResponse)
async def enhance_prompt(request: EnhanceRequest, api_key: str = Depends(verify_api_key)):
    """Enhance the submitted prompt with the local Gemma model.

    NOTE(review): this handler was never registered on `app`, so the API
    exposed no route at all; the @app.post decorator restores it (the
    Depends-based signature only makes sense as a FastAPI endpoint).

    Args:
        request: body carrying the prompt to enhance.
        api_key: validated bearer token (injected by verify_api_key).

    Returns:
        EnhanceResponse with the model's enhanced prompt.

    Raises:
        HTTPException: 500 on empty model output or any generation failure.
    """
    # Gemma chat template: a single user turn, then the model turn.
    full_prompt = f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{request.prompt}<end_of_turn>\n<start_of_turn>model\n"
    try:
        result = llm(
            full_prompt,
            max_tokens=512,
            temperature=0.7,
            top_k=40,
            top_p=0.95,
            repeat_penalty=1.1,
            stop=["<end_of_turn>"],
        )
        enhanced_prompt = result["choices"][0]["text"].strip()
        if not enhanced_prompt:
            raise HTTPException(status_code=500, detail="Enhancement failed: Empty response")
        return EnhanceResponse(enhanced_prompt=enhanced_prompt)
    except HTTPException:
        # Fix: the broad handler below used to catch the empty-response
        # HTTPException and re-wrap its repr into a new detail string;
        # let deliberate HTTP errors propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Enhancement failed: {str(e)}")
def enhance_for_gradio(prompt_text, api_key):
    """Gradio handler: validate inputs, then run the enhancement model.

    Args:
        prompt_text: the user's prompt (may be blank).
        api_key: the shared secret typed into the UI.

    Returns:
        str: the enhanced prompt, or a human-readable error message
        (Gradio renders whichever string comes back).
    """
    # Guard clauses: blank prompt, blank key, wrong key.
    if not prompt_text.strip():
        return "Please enter a prompt to enhance."
    if not api_key.strip():
        return "Please enter your API key."
    if api_key != API_KEY:
        return "Invalid API key."
    # Gemma chat template: a single user turn, then the model turn.
    full_prompt = f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{prompt_text}<end_of_turn>\n<start_of_turn>model\n"
    try:
        # Fix (the Space's build error): the original passed dict-literal
        # syntax inside the call (`min_p: 0.01`, plus a duplicate
        # `repeat_penalty: 1.0` after `repeat_penalty=1.1`) — a
        # SyntaxError. Rewritten as valid keyword arguments, keeping the
        # one syntactically valid repeat_penalty value (1.1).
        result = llm(
            full_prompt,
            max_tokens=512,
            temperature=1,
            top_k=64,
            top_p=0.95,
            min_p=0.01,
            repeat_penalty=1.1,
            stop=["<end_of_turn>"],
        )
        enhanced_prompt = result["choices"][0]["text"].strip()
        if not enhanced_prompt:
            return "Model generated empty response."
        return enhanced_prompt
    except Exception as e:
        return f"Enhancement failed: {str(e)}"
# Gradio UI: two inputs (the prompt plus a password-masked API key) and a
# single text output showing the enhanced prompt.
iface = gr.Interface(
    fn=enhance_for_gradio,
    inputs=[
        gr.Textbox(
            lines=5,
            placeholder="Enter your prompt here to enhance it...",
            label="Original Prompt",
        ),
        gr.Textbox(
            placeholder="Enter your API key",
            label="API Key",
            type="password",
        ),
    ],
    outputs=gr.Textbox(lines=8, label="Enhanced Prompt"),
    title="AI Prompt Enhancer",
    description="Transform your basic prompts into detailed, effective instructions. API key required.",
    cache_examples=False,
)
def run_gradio():
    """Serve the Gradio UI on all interfaces, port 7860 (blocking call)."""
    iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
if __name__ == "__main__":
    # Gradio runs on a daemon thread so uvicorn can own the main thread;
    # the daemon flag lets the process exit once the API server stops.
    ui_thread = threading.Thread(target=run_gradio, daemon=True)
    ui_thread.start()
    uvicorn.run(app, host="0.0.0.0", port=8000)