Ashrafb commited on
Commit
7807c8e
·
verified ·
1 Parent(s): 86817a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -1
app.py CHANGED
@@ -1,4 +1,108 @@
 
 
 
 
1
  import os
 
 
2
 
 
3
 
4
- exec(os.environ.get('CODE'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import base64
4
+ from typing import Iterator
5
  import os
6
+ from text_generation import Client
7
+ from deep_translator import GoogleTranslator
8
 
9
+ model_id = os.environ.get("CODE", None)
10
 
11
+ API_URL = "https://api-inference.huggingface.co/models/" + model_id
12
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
13
+
14
+ client = Client(
15
+ API_URL,
16
+ headers={"Authorization": f"Bearer {HF_TOKEN}"},
17
+ )
18
+ EOS_STRING = "</s>"
19
+ EOT_STRING = "<EOT>"
20
+
21
+ translator_to_en = GoogleTranslator(source='arabic', target='english')
22
+ translator_to_ar = GoogleTranslator(source='english', target='arabic')
23
+
24
+ def get_prompt(message: str, chat_history: list[tuple[str, str]],
25
+ system_prompt: str) -> str:
26
+ texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
27
+ do_strip = False
28
+ for user_input, response in chat_history:
29
+ user_input = user_input.strip() if do_strip else user_input
30
+ do_strip = True
31
+ texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
32
+ message = message.strip() if do_strip else message
33
+ texts.append(f'{message} [/INST]')
34
+ return ''.join(texts)
35
+
36
+
37
+ def run(message: str,
38
+ chat_history: list[tuple[str, str]],
39
+ system_prompt: str,
40
+ max_new_tokens: int = 1024,
41
+ temperature: float = 0.1,
42
+ top_p: float = 0.9,
43
+ top_k: int = 50) -> Iterator[str]:
44
+
45
+ prompt = get_prompt(message, chat_history, system_prompt)
46
+
47
+ generate_kwargs = dict(
48
+ max_new_tokens=max_new_tokens,
49
+ do_sample=True,
50
+ top_p=top_p,
51
+ top_k=top_k,
52
+ temperature=temperature,
53
+ )
54
+
55
+ stream = client.generate_stream(prompt, **generate_kwargs)
56
+ output = ""
57
+
58
+ for response in stream:
59
+ if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
60
+ translated_output = translator_to_ar.translate(output)
61
+ yield translated_output
62
+ output = ""
63
+ else:
64
+ output += response.token.text
65
+
66
+
67
+ def generate_image_caption(image_data):
68
+ image_base64 = base64.b64encode(image_data).decode('utf-8')
69
+ payload = {"data": ["data:image/jpeg;base64," + image_base64]}
70
+ response = requests.post("https://ashrafb-salesforce-blip-image-captioning-base.hf.space/run/predict", json=payload)
71
+ if response.status_code == 200:
72
+ caption = response.json()["data"][0]
73
+ return caption
74
+ else:
75
+ return "Error: Unable to generate caption"
76
+
77
+
78
+ def main():
79
+ st.markdown('<p style="color:#191970;text-align:center;font-size:30px;">Aiconvert.online img2story</p>', unsafe_allow_html=True)
80
+
81
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
82
+
83
+ if uploaded_file is not None:
84
+ image_data = uploaded_file.read()
85
+ st.image(image_data, caption="Uploaded Image.", use_column_width=True)
86
+
87
+ if st.button("Generate Story"):
88
+ system_prompt = "write attractive story in 300 words about"
89
+
90
+ if uploaded_file is not None:
91
+ caption = generate_image_caption(image_data)
92
+
93
+ if caption.startswith("Error"):
94
+ st.error(caption)
95
+ return
96
+
97
+ with st.spinner("Generating story..."): # Adding a spinner while generating the story
98
+ ai_response = next(run(caption, [], system_prompt))
99
+
100
+ # Display the generated story
101
+ st.subheader("Generated Story:")
102
+ st.write(ai_response, unsafe_allow_html=True)
103
+ else:
104
+ st.warning("Please upload an image.")
105
+ return
106
+
107
+ if __name__ == "__main__":
108
+ main()