File size: 1,803 Bytes
9ff5556
 
 
 
 
 
 
 
 
 
 
81e26a7
 
9ff5556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98d2562
 
9978c21
98d2562
 
9978c21
98d2562
 
 
 
1943f0d
98d2562
9ff5556
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from pydantic import BaseModel

from .ConfigEnv import config
from fastapi.middleware.cors import CORSMiddleware

from langchain.llms import Clarifai
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from TextGen import app

import requests

class Generate(BaseModel):
    """Response body of the text-generation endpoint.

    Wraps the model's generated output as a single string.
    """

    text: str

def generate_text(prompt: str):
    """Run *prompt* through the configured Clarifai model via an LLMChain.

    Args:
        prompt: Raw user text to send to the model, verbatim.

    Returns:
        A ``Generate`` wrapping the model output, or the dict
        ``{"detail": "Please provide a prompt."}`` when *prompt* is empty,
        ``None``, or whitespace-only.
    """
    # Guard clause: reject empty / blank input before building any LLM objects.
    if not prompt or not prompt.strip():
        return {"detail": "Please provide a prompt."}

    # Pass the user text as a *variable* into a fixed "{Prompt}" template rather
    # than using it as the template itself: with the default f-string template
    # format, any literal "{" or "}" in user text would otherwise be parsed as
    # a placeholder and break formatting.
    template = PromptTemplate(template="{Prompt}", input_variables=["Prompt"])

    llm = Clarifai(
        pat=config.CLARIFAI_PAT,
        user_id=config.USER_ID,
        app_id=config.APP_ID,
        model_id=config.MODEL_ID,
        model_version_id=config.MODEL_VERSION_ID,
    )

    llmchain = LLMChain(prompt=template, llm=llm)

    llm_response = llmchain.run({"Prompt": prompt})
    return Generate(text=llm_response)

        

# Register CORS handling on the shared FastAPI app: every origin, method and
# header is allowed, and credentials are allowed too.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is very
# permissive. Per Starlette's CORSMiddleware, cookie-bearing requests get the
# caller's Origin echoed back instead of "*", so any site could make
# credentialed calls. If no client sends credentials, consider
# allow_credentials=False or an explicit origin list. Verify against the
# installed Starlette version and with the frontend owners.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/", tags=["Home"])
def api_home():
    """Return a static welcome message for the API root."""
    welcome = {"detail": "Welcome to FastAPI TextGen Tutorial!"}
    return welcome


# Hugging Face hosted inference endpoint for google/gemma-1.1-7b-it;
# `query` POSTs its payloads here.
API_URL = "https://api-inference.huggingface.co/models/google/gemma-1.1-7b-it"
# Bearer-token auth header, built once at import time from config.HF_TOKEN.
headers = dict(Authorization=f"Bearer {config.HF_TOKEN}")

def query(payload, timeout=60):
    """POST *payload* as JSON to the Hugging Face inference endpoint.

    Args:
        payload: JSON-serialisable request body, e.g.
            ``{"inputs": ..., "parameters": {...}}``.
        timeout: Seconds to wait for the connection and for the response,
            forwarded to ``requests.post``. ``None`` waits indefinitely.

    Returns:
        The decoded JSON response body, regardless of HTTP status code.

    Raises:
        requests.Timeout: if the server does not respond within *timeout*.
        An exception from ``response.json()`` if the body is not valid JSON
        (e.g. an HTML error page).
    """
    # Without a timeout, requests waits forever on a stalled server, which
    # would block the calling endpoint indefinitely.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
    
@app.get("/query", tags=["Query"])
def api_query(input: str):
    """Forward *input* to the Hugging Face model via ``query``, returning its JSON.

    ``input`` shadows the builtin, but it is the public query-parameter name,
    so it is kept.
    """
    # NOTE(review): the HF text-generation task documents `max_new_tokens`;
    # confirm that `max_length` is actually honoured by this endpoint.
    generation_params = {"return_full_text": False, "max_length": 1024}
    payload = {"inputs": input, "parameters": generation_params}
    return query(payload)

@app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
def inference(input_prompt: str):
    """Generate text from *input_prompt* using ``generate_text``.

    Returns:
        The ``Generate`` produced by ``generate_text``.

    Raises:
        HTTPException: 400 when ``generate_text`` rejects the prompt, i.e.
            returns a ``{"detail": ...}`` dict instead of a ``Generate``.
    """
    # fastapi is already a dependency of this module; HTTPException just is
    # not imported at file level.
    from fastapi import HTTPException

    result = generate_text(prompt=input_prompt)
    # A rejection dict has no "text" field. Returning it would fail validation
    # against response_model=Generate and surface as a 500 server error, so
    # report it as a client error instead.
    if isinstance(result, dict):
        raise HTTPException(status_code=400, detail=result.get("detail", "Invalid prompt."))
    return result