# Spaces: Runtime error
import os

import requests
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI  # model server
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_groq import ChatGroq

import mongo_utils as mongo
from config import app_config
# Groq API key, read from the environment rather than hard-coded.
# A key committed to source (and published with the app) is leaked and should
# be rotated; keep it in a secret / environment variable instead.
# If the variable is unset this is None; per the langchain_groq docs, ChatGroq
# then falls back to reading GROQ_API_KEY from the environment itself.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
def __image2text(image, timeout=60):
    """Return a short caption for *image* from the image-to-text endpoint.

    Posts the raw image bytes to ``app_config.I2T_API_URL``, sending
    ``app_config.HF_TOKEN`` as the ``Authorization`` header, and reads
    ``generated_text`` from the first element of the JSON reply.

    Args:
        image: Raw image bytes, sent as the request body.
        timeout: Seconds to wait for the endpoint before giving up. This is
            optional; previously the request had no timeout and could block
            indefinitely. The 60 s default is a guess; tune for the endpoint.

    Returns:
        The caption string. Returns ``""`` if the request fails, the endpoint
        returns an HTTP error status, or the reply does not have the expected
        ``[{"generated_text": ...}]`` shape. The error is printed, keeping the
        original best-effort behaviour.
    """
    headers = {"Authorization": app_config.HF_TOKEN}
    try:
        response = requests.post(
            app_config.I2T_API_URL,
            headers=headers,
            data=image,
            timeout=timeout,
        )
        # Surface HTTP errors (e.g. model still loading) instead of trying to
        # index into an error payload.
        response.raise_for_status()
        return response.json()[0]["generated_text"]
    except (
        requests.RequestException,  # network / HTTP errors
        ValueError,  # body is not valid JSON
        KeyError,  # payload lacks the expected key (e.g. an error dict)
        IndexError,  # empty list returned
        TypeError,  # payload has an unexpected type
    ) as e:
        # Previously a failure here left ``response`` unbound
        # (UnboundLocalError), or returned the raw Response object as if it
        # were a caption.
        print(e)
        return ""
def __text2story(image_desc, genre, style, word_count, creativity):
    """Generate a short story based on an image-description text prompt.

    Builds a system + human chat prompt and runs it through the Groq-hosted
    ``llama3-8b-8192`` chat model with an ``LLMChain``.

    Args:
        image_desc: Caption text used as the story context.
        genre: Genre the story should be written in.
        style: Writing style the story should use.
        word_count: Maximum story length, in words, stated to the model.
        creativity: Sampling temperature passed to the model.

    Returns:
        The story text produced by the chain.
    """
    # Forward ``creativity`` as the temperature, as the previous
    # OpenAI-backed implementation did. The switch to Groq had hard-coded
    # 0.0, which silently discarded the caller's setting.
    story_model = ChatGroq(
        model="llama3-8b-8192",
        temperature=creativity,
        api_key=GROQ_API_KEY,
    )

    # System message: role and constraints (length, genre, style).
    sys_prompt = PromptTemplate(
        template="""You are an expert story writer, write a maximum of {word_count}
words long story in {genre} genre in {style} writing style, based on the user
provided story-context.
""",
        input_variables=["word_count", "genre", "style"],
    )
    system_msg_prompt = SystemMessagePromptTemplate(prompt=sys_prompt)

    # Human message: the image description the story is based on.
    human_prompt = PromptTemplate(
        template="story-context: {context}", input_variables=["context"]
    )
    human_msg_prompt = HumanMessagePromptTemplate(prompt=human_prompt)

    chat_prompt = ChatPromptTemplate.from_messages(
        [system_msg_prompt, human_msg_prompt]
    )

    story_chain = LLMChain(llm=story_model, prompt=chat_prompt)
    return story_chain.run(
        genre=genre, style=style, word_count=word_count, context=image_desc
    )
def generate_story(image_file, genre, style, word_count, creativity):
    """Turn an image file into a short story and report usage counts.

    The image is captioned with ``__image2text`` and the caption is expanded
    into a story with ``__text2story``. The access counter is then bumped
    through ``mongo.increment_curr_access_count``.

    Args:
        image_file: Path of the image to read (opened in binary mode).
        genre: Story genre, forwarded to the story generator.
        style: Writing style, forwarded to the story generator.
        word_count: Maximum story length in words, forwarded to the generator.
        creativity: Forwarded unchanged to ``__text2story``.

    Returns:
        A 4-tuple ``(story, max_count, curr_count, available_count)``.
        The two counts are read from ``app_config`` after the counter
        increment, and ``available_count`` is ``max_count - curr_count``.
    """
    with open(image_file, "rb") as image_handle:
        raw_bytes = image_handle.read()

    caption = __image2text(image=raw_bytes)

    # Debug output: the caption framed by divider lines.
    divider = "++++++++++++++++++++++++++++++++++++++"
    for line in (divider, caption, divider):
        print(line)

    story = __text2story(
        image_desc=caption,
        genre=genre,
        style=style,
        word_count=word_count,
        creativity=creativity,
    )

    # Record this generation, then report the quota values held in app_config.
    mongo.increment_curr_access_count()
    limit = app_config.openai_max_access_count
    used = app_config.openai_curr_access_count
    return story, limit, used, limit - used