Ashrafb commited on
Commit
dfe1acb
·
1 Parent(s): 73b61c3

Upload app (16).py

Browse files
Files changed (1) hide show
  1. app (16).py +100 -0
app (16).py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import base64
4
+ from typing import Iterator
5
+ import os
6
+ from text_generation import Client
7
+
8
+ model_id = 'codellama/CodeLlama-34b-Instruct-hf'
9
+
10
+ API_URL = "https://api-inference.huggingface.co/models/" + model_id
11
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
12
+
13
+ client = Client(
14
+ API_URL,
15
+ headers={"Authorization": f"Bearer {HF_TOKEN}"},
16
+ )
17
+ EOS_STRING = "</s>"
18
+ EOT_STRING = "<EOT>"
19
+
20
+ def get_prompt(message: str, chat_history: list[tuple[str, str]],
21
+ system_prompt: str) -> str:
22
+ texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
23
+ do_strip = False
24
+ for user_input, response in chat_history:
25
+ user_input = user_input.strip() if do_strip else user_input
26
+ do_strip = True
27
+ texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
28
+ message = message.strip() if do_strip else message
29
+ texts.append(f'{message} [/INST]')
30
+ return ''.join(texts)
31
+
32
+
33
+ def run(message: str,
34
+ chat_history: list[tuple[str, str]],
35
+ system_prompt: str,
36
+ max_new_tokens: int = 1024,
37
+ temperature: float = 0.1,
38
+ top_p: float = 0.9,
39
+ top_k: int = 50) -> Iterator[str]:
40
+ prompt = get_prompt(message, chat_history, system_prompt)
41
+
42
+ generate_kwargs = dict(
43
+ max_new_tokens=max_new_tokens,
44
+ do_sample=True,
45
+ top_p=top_p,
46
+ top_k=top_k,
47
+ temperature=temperature,
48
+ )
49
+ stream = client.generate_stream(prompt, **generate_kwargs)
50
+ output = ""
51
+ for response in stream:
52
+ if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
53
+ yield output
54
+ output = ""
55
+ else:
56
+ output += response.token.text
57
+
58
+
59
+ def generate_image_caption(image_data):
60
+ image_base64 = base64.b64encode(image_data).decode('utf-8')
61
+ payload = {"data": ["data:image/jpeg;base64," + image_base64]}
62
+ response = requests.post("https://ashrafb-salesforce-blip-image-captioning-base.hf.space/run/predict", json=payload)
63
+ if response.status_code == 200:
64
+ caption = response.json()["data"][0]
65
+ return caption
66
+ else:
67
+ return "Error: Unable to generate caption"
68
+
69
+
70
+ def main():
71
+ st.title("AIconvert img2story")
72
+
73
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
74
+
75
+ if uploaded_file is not None:
76
+ image_data = uploaded_file.read()
77
+ st.image(image_data, caption="Uploaded Image.", use_column_width=True)
78
+
79
+ if st.button("Generate Story"):
80
+ system_prompt = "write story in 100 words about"
81
+
82
+ if uploaded_file is not None:
83
+ caption = generate_image_caption(image_data)
84
+
85
+ if caption.startswith("Error"):
86
+ st.error(caption)
87
+ return
88
+
89
+ ai_response = next(run(caption, [], system_prompt))
90
+ else:
91
+ st.warning("Please upload an image.")
92
+ return
93
+
94
+ # Display the generated story
95
+ st.subheader("Generated Story:")
96
+ st.write(ai_response, unsafe_allow_html=True)
97
+
98
+
99
+ if __name__ == "__main__":
100
+ main()