from fastapi import HTTPException
from pydantic import BaseModel
from .ConfigEnv import config
from fastapi.middleware.cors import CORSMiddleware
from langchain.llms import Clarifai
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from TextGen import app
import requests


class Generate(BaseModel):
    text: str


def generate_text(prompt: str):
    if prompt == "":
        # Raise a proper 400 instead of returning a dict, which would fail
        # validation against the Generate response model.
        raise HTTPException(status_code=400, detail="Please provide a prompt.")

    # Use a pass-through template and keep the caller's prompt string intact,
    # so it can be supplied as the "Prompt" input variable below. (The original
    # code rebound `prompt` to the PromptTemplate and then passed that object
    # into the chain, which is a bug.)
    prompt_template = PromptTemplate(template="{Prompt}", input_variables=["Prompt"])
    llm = Clarifai(
        pat=config.CLARIFAI_PAT,
        user_id=config.USER_ID,
        app_id=config.APP_ID,
        model_id=config.MODEL_ID,
        model_version_id=config.MODEL_VERSION_ID,
    )
    llmchain = LLMChain(prompt=prompt_template, llm=llm)
    # Run the chain with the user's prompt string as the "Prompt" variable.
    llm_response = llmchain.run({"Prompt": prompt})
    return Generate(text=llm_response)


# Allow cross-origin requests from any origin (fine for a tutorial;
# tighten allow_origins for production).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/", tags=["Home"])
def api_home():
    return {"detail": "Welcome to FastAPI TextGen Tutorial!"}


API_URL = "https://api-inference.huggingface.co/models/google/gemma-1.1-7b-it"
headers = {"Authorization": f"Bearer {config.HF_TOKEN}"}


def query(payload):
    # Forward the payload to the Hugging Face Inference API and return the
    # decoded JSON response.
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()


@app.get("/query", tags=["Query"])
def api_query(input: str):
    return query(
        {
            "inputs": input,
            "parameters": {"return_full_text": False, "max_length": 1024},
        }
    )


@app.post(
    "/api/generate",
    summary="Generate text from prompt",
    tags=["Generate"],
    response_model=Generate,
)
def inference(input_prompt: str):
    return generate_text(prompt=input_prompt)
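

# --- Usage sketch ---
# A minimal way to exercise the three endpoints, assuming `app` is exported
# from the TextGen package (as the `from TextGen import app` import suggests)
# and the server is started with uvicorn. The exact module path and port are
# assumptions, not confirmed by this file. Shown as comments so the module
# stays importable:
#
#   uvicorn TextGen:app --reload
#
#   # Home endpoint
#   curl http://127.0.0.1:8000/
#
#   # Hugging Face inference endpoint (GET with a query parameter)
#   curl "http://127.0.0.1:8000/query?input=What%20is%20FastAPI%3F"
#
#   # Clarifai/LangChain generation endpoint (POST; input_prompt is a
#   # query parameter here, not a JSON body)
#   curl -X POST "http://127.0.0.1:8000/api/generate?input_prompt=Write%20a%20haiku"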