| from fastapi import FastAPI, Request, HTTPException |
| from fastapi.responses import JSONResponse, FileResponse |
| from fastapi.staticfiles import StaticFiles |
| |
| import json |
| import os |
| from groq import Groq |
|
|
# ASGI application object; the CORS middleware and the /generate/ route are
# registered on it further down in this module.
app = FastAPI()

# Groq API client, authenticated with the key found in the GROQ_API_KEY
# environment variable (None is passed through when the variable is unset).
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
|
# System prompt placed at the head of every conversation sent to the model.
# Adjacent string literals are joined with NO separator, so every fragment must
# carry its own trailing space or newline; the fragment ending in
# "false information." previously lacked one, producing "information.Always".
SYSTEM_MESSAGE = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
    "as possible, while being safe. Your answers should not include any harmful, "
    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
    "does not make any sense, or is not factually coherent, explain why instead of "
    "answering something not correct. If you don't know the answer to a question, please "
    "don't share false information. "
    "Always respond in the language of user prompt for each prompt ."
)

# Generation settings forwarded to the chat-completions request.
MAX_TOKENS = 2000  # upper bound on tokens generated per reply
TEMPERATURE = 0.7
TOP_P = 0.95

# Model identifier sent with every completion request.
# NOTE(review): confirm this id is still served by the provider before deploying.
GROQ_MODEL_NAME = "llama3-8b-8192"
|
|
def respond(message, history: list[tuple[str, str]]):
    """Stream the model's reply to *message* as a sequence of text fragments.

    The conversation sent to the model is the system prompt, followed by the
    non-empty user/assistant entries of *history* (element 0 of each entry is
    the user text, element 1 the assistant text), followed by *message* as the
    final user turn.

    Yields:
        Each non-None content delta of the streamed completion, in order.
    """
    conversation = [{"role": "system", "content": SYSTEM_MESSAGE}]

    # Position inside a history entry -> chat role; empty slots are skipped.
    slots = (("user", 0), ("assistant", 1))
    for turn in history:
        for role, position in slots:
            text = turn[position]
            if text:
                conversation.append({"role": role, "content": text})

    conversation.append({"role": "user", "content": message})

    stream = client.chat.completions.create(
        messages=conversation,
        model=GROQ_MODEL_NAME,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )

    for event in stream:
        # Some stream events carry no choices at all; nothing to emit for them.
        if not event.choices:
            continue
        fragment = event.choices[0].delta.content
        if fragment is not None:
            yield fragment
|
|
|
|
| from fastapi.middleware.cors import CORSMiddleware |
|
|
# Enable CORS for exactly one origin; credentials are allowed, and any HTTP
# method or request header is accepted from that origin.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["https://artixiban-ll3.static.hf.space"],
)
|
|
@app.post("/generate/")
async def generate(request: Request):
    """Answer one chat prompt and return the complete reply as JSON.

    Form fields read from the request:
        prompt:  the user's message; required and non-empty.
        history: optional JSON array of earlier turns, each entry holding the
                 user text at index 0 and the assistant text at index 1.
                 Defaults to an empty array.

    Returns:
        JSONResponse with the body ``{"response": <full reply text>}``.

    Raises:
        HTTPException: 403 when the ``Origin`` header is not the allowed
            origin; 400 when the prompt is missing or not text, or when
            ``history`` is not a valid JSON array.
    """
    # Function-scope import so this handler does not depend on the module's
    # import block being edited.
    import asyncio

    allowed_origin = "https://artixiban-ll3.static.hf.space"
    if request.headers.get("origin") != allowed_origin:
        raise HTTPException(status_code=403, detail="Origin not allowed")

    form = await request.form()

    prompt = form.get("prompt")
    # A form field can also be an uploaded file; only plain text is a prompt.
    if not prompt or not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt is required")

    # Malformed client input must surface as a 400, not an unhandled 500.
    try:
        history = json.loads(form.get("history", "[]"))
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=400, detail="History must be valid JSON"
        ) from exc
    if not isinstance(history, list):
        raise HTTPException(status_code=400, detail="History must be a JSON array")

    # respond() iterates a blocking, synchronous stream. Draining it in a
    # worker thread keeps the event loop free to serve other requests.
    final_response = await asyncio.to_thread(
        lambda: "".join(respond(prompt, history))
    )

    return JSONResponse(content={"response": final_response})