File size: 3,928 Bytes
7807c8e
 
 
 
dfe1acb
7807c8e
 
dfe1acb
7807c8e
dfe1acb
7807c8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95849dc
5bc2c51
 
 
 
 
 
 
7807c8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
import requests
import base64
from typing import Iterator
import os
from text_generation import Client
from deep_translator import GoogleTranslator

# Model repository id on the Hugging Face Hub, read from the CODE environment
# variable. Fail early with a clear message: concatenating None into the URL
# below would otherwise raise an opaque TypeError.
model_id = os.environ.get("CODE", None)
if not model_id:
    raise RuntimeError(
        "Environment variable CODE must be set to the Hugging Face model id."
    )

API_URL = "https://api-inference.huggingface.co/models/" + model_id
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Streaming text-generation client. The Authorization header is only sent when
# a token is configured; otherwise the header value would be "Bearer None".
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)

# Token texts that mark the end of a generated sequence.
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"

# Translators between Arabic and English (Google Translate via deep_translator).
translator_to_en = GoogleTranslator(source='arabic', target='english')
translator_to_ar = GoogleTranslator(source='english', target='arabic')

def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Assemble the full chat prompt string sent to the model.

    The prompt starts with the system prompt wrapped in ``<<SYS>>`` markers,
    followed by one ``[INST] ... [/INST] answer </s><s>`` segment per past
    exchange, and ends with the new message awaiting an answer.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user_input, response)`` pairs, oldest first.
        system_prompt: Instruction text placed in the system section.

    Returns:
        The concatenated prompt.
    """
    header = f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n'

    turns = []
    for index, (question, answer) in enumerate(chat_history):
        # The very first user text is kept verbatim; later ones are trimmed.
        if index > 0:
            question = question.strip()
        turns.append(f'{question} [/INST] {answer.strip()} </s><s>[INST] ')

    # The new message is only trimmed when it follows at least one exchange.
    final_message = message.strip() if turns else message
    return header + ''.join(turns) + f'{final_message} [/INST]'


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        top_p: float = 0.9,
        top_k: int = 50) -> Iterator[str]:
    """Stream a completion for *message* and yield it translated by ``translator_to_ar``.

    The prompt is built with ``get_prompt`` and sent to the module-level
    ``client``. Token texts are accumulated until a token containing
    ``EOS_STRING`` or ``EOT_STRING`` arrives; the accumulated text is then
    translated and yielded as one chunk. Text still pending when the stream
    ends without an end token (for example when the token limit is reached)
    is translated and yielded as a final chunk instead of being dropped.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user_input, response)`` pairs.
        system_prompt: Instruction text for the system section of the prompt.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered when sampling.

    Yields:
        Translated text chunks.
    """
    prompt = get_prompt(message, chat_history, system_prompt)

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )

    stream = client.generate_stream(prompt, **generate_kwargs)
    end_tokens = (EOS_STRING, EOT_STRING)
    pieces: list[str] = []

    for response in stream:
        token_text = response.token.text
        if any(end_token in token_text for end_token in end_tokens):
            # The end-marker token itself is never part of the output.
            yield translator_to_ar.translate("".join(pieces))
            pieces.clear()
        else:
            pieces.append(token_text)

    # Flush text that was generated but never terminated by an end token;
    # without this the generator could finish without yielding anything.
    remaining = "".join(pieces)
    if remaining:
        yield translator_to_ar.translate(remaining)


def generate_image_caption(image_data):
    """Request a caption for an image from the hosted BLIP captioning Space.

    The image bytes are base64-encoded into a data URL and posted as JSON.

    Args:
        image_data: Raw image bytes.

    Returns:
        The caption string on success, otherwise the string
        ``"Error: Unable to generate caption"``. Callers detect failure by the
        ``"Error"`` prefix, so every failure mode (non-200 status, network
        error, timeout, unexpected response body) maps to that same string.
    """
    error_message = "Error: Unable to generate caption"

    image_base64 = base64.b64encode(image_data).decode('utf-8')
    payload = {"data": ["data:image/jpeg;base64," + image_base64]}

    try:
        # A timeout keeps the app from hanging forever on an unresponsive host.
        response = requests.post(
            "https://ashrafb-salesforce-blip-image-captioning-base.hf.space/run/predict",
            json=payload,
            timeout=60,
        )
    except requests.RequestException:
        return error_message

    if response.status_code != 200:
        return error_message

    try:
        caption = response.json()["data"][0]
    except (ValueError, KeyError, IndexError, TypeError):
        # Body was not JSON or did not have the expected {"data": [...]} shape.
        return error_message

    # Callers treat the result as text, so reject anything that is not a string.
    if not isinstance(caption, str):
        return error_message
    return caption


def main():
    """Render the Streamlit page: upload an image, caption it, show a story.

    When "Generate Story" is pressed with an image uploaded, the image is
    captioned via ``generate_image_caption`` and the caption is passed to
    ``run`` to produce the story text, which is then displayed. Failures in
    either step are reported with ``st.error`` instead of raising.
    """
    st.markdown('<p style="color:crimson;text-align:center;font-size:30px;">Aiconvert.online img2story</p>', unsafe_allow_html=True)
    # CSS that hides Streamlit's default menu and footer.
    hide_streamlit_style = """
            <style>
            #MainMenu {visibility: hidden;}
            footer {visibility: hidden;}
            </style>
            """
    st.markdown(hide_streamlit_style, unsafe_allow_html=True)
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

    image_data = None
    if uploaded_file is not None:
        # getvalue() returns the whole buffer regardless of the current read
        # position, unlike read(), which returns nothing once consumed.
        image_data = uploaded_file.getvalue()
        st.image(image_data, caption="Uploaded Image.", use_column_width=True)

    if not st.button("Generate Story"):
        return

    if uploaded_file is None:
        st.warning("Please upload an image.")
        return

    system_prompt = "write attractive story in 300 words about"

    caption = generate_image_caption(image_data)
    if caption.startswith("Error"):
        st.error(caption)
        return

    with st.spinner("Generating story..."):  # Spinner while the story is generated
        # A default avoids StopIteration when the generator yields nothing.
        ai_response = next(run(caption, [], system_prompt), None)

    if not ai_response:
        st.error("Error: Unable to generate story")
        return

    # Display the generated story
    st.subheader("Generated Story:")
    # NOTE(review): model output is rendered with unsafe_allow_html=True;
    # confirm that HTML rendering of generated text is actually intended.
    st.write(ai_response, unsafe_allow_html=True)

# Script entry point: build the page only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()