| | import gradio as gr |
| | import openai |
| | from openai import OpenAI |
| | import google.generativeai as genai |
| | import os |
| | import io |
| | import base64 |
| |
|
| | |
# Provider credential read from the environment (None when unset); its prefix
# selects the backend in endpoints().
api_key = os.getenv("API_KEY")
| | |
| |
|
| | |
# Model identifier passed verbatim to the provider SDK.
MODEL = os.environ.get("MODEL")
if not MODEL:
    # Fail fast with a readable message instead of an opaque TypeError on a
    # missing variable (or a confusing provider error on the first request).
    raise RuntimeError("The MODEL environment variable must be set to a model id.")
# Short display name for the footer: the text after the last "/" (the whole
# id when it contains no "/").
MODEL_NAME = MODEL.rsplit("/", 1)[-1]
| |
|
def read(filename, encoding="utf-8"):
    """Return the full text content of *filename*.

    Args:
        filename: Path to a text file.
        encoding: Codec used to decode the file. Defaults to UTF-8 so that
            non-ASCII prompt text decodes identically on every platform,
            instead of depending on the OS locale's default encoding.

    Returns:
        The file's contents as a ``str``.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
        UnicodeDecodeError: if the bytes are not valid in *encoding*.
    """
    with open(filename, encoding=encoding) as f:
        return f.read()
| | |
# System prompt appended to every model request; loaded once at import time
# from a path relative to the current working directory.
SYS_PROMPT = read(filename="system_prompt.txt")
| |
|
| |
|
# HTML header rendered via gr.Markdown at the top of the UI.
# Content (Chinese): title "知觉demo"; an AI built on prompts and frontier
# multimodal models that helps interpret domain-specific content; choose a
# domain, upload an image (see the examples) or send text to interpret;
# generated interpretations are for reference only.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">知觉demo</h1>
<p>🩺一个基于提示词和前沿多模态模型的AI,帮助您解读专业领域内容。</p>
<p>🔎 您可以选择领域,参考示例上传图像,或发送需要解读的文字内容。</p>
<p>🦕 生成解读内容仅供参考。</p>
</div>
'''
| |
|
| |
|
# Custom stylesheet passed to gr.Blocks(css=...): centre <h1> headings and
# hide any <footer> element.
css = """
h1 {
text-align: center;
display: block;
}
footer {
display:none !important
}
"""
| |
|
| |
|
# Footer text shown under the UI: "采用 <model> 模型" ("uses the <model> model").
LICENSE = f"采用 {MODEL_NAME} 模型"
| |
|
def endpoints(api_key):
    """Pick the LLM backend implied by the format of an API key.

    Args:
        api_key: Provider credential string, or ``None`` when unset.

    Returns:
        ``'OPENAI'`` for keys beginning with ``sk-``, ``'GOOGLE'`` for any
        other string (including ``""``), and ``None`` when *api_key* is ``None``.
    """
    if api_key is None:
        return None
    return 'OPENAI' if api_key.startswith('sk-') else 'GOOGLE'
| |
|
def process_text(text_input, unit):
    """Ask the configured LLM to analyse a piece of text from the *unit* domain.

    The backend comes from ``endpoints(api_key)``: OpenAI chat completions for
    ``sk-`` keys, Google Gemini for any other key.

    Args:
        text_input: User-supplied text to analyse.
        unit: Domain label from the UI dropdown; interpolated into the
            system prompt.

    Returns:
        The model's reply, or ``""`` when *text_input* is empty or no API key
        is configured.
    """
    print(text_input)
    endpoint = endpoints(api_key)
    if not text_input:
        return ""
    # Role sentence and SYS_PROMPT are kept apart by a space so they do not
    # run together into one word.
    system_prompt = f"You are an experienced Analyst in {unit}. " + SYS_PROMPT
    request = f"Could you analyze {text_input}?"
    if endpoint == 'OPENAI':
        client = OpenAI(api_key=api_key)
        completion = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Hello! {request}"},
            ],
        )
        return completion.choices[0].message.content
    if endpoint == 'GOOGLE':
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(model_name=MODEL)
        # Gemini receives one combined prompt: separate the system prompt from
        # the request so the end of SYS_PROMPT does not merge into it.
        response = model.generate_content(f"{system_prompt}\n\n{request}")
        return response.text
    return ""
| |
|
def encode_image_to_base64(image_input):
    """Serialize an image as JPEG and return it base64-encoded.

    Args:
        image_input: A PIL ``Image`` (any object exposing PIL's ``mode``,
            ``convert`` and ``save`` API).

    Returns:
        The JPEG bytes as an ASCII base64 ``str`` (no ``data:`` URL prefix).
    """
    # JPEG stores neither alpha nor a palette; Pillow refuses to save e.g.
    # RGBA/LA/P images as JPEG, so flatten anything but RGB/L to RGB first.
    if image_input.mode not in ("RGB", "L"):
        image_input = image_input.convert("RGB")
    buffered = io.BytesIO()
    image_input.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
| |
|
def process_image(image_input, unit):
    """Ask the configured LLM to describe and analyse an uploaded image.

    The backend comes from ``endpoints(api_key)``: OpenAI chat completions
    (image sent inline as a base64 JPEG data URL) for ``sk-`` keys, Google
    Gemini (image object passed directly) for any other key.

    Args:
        image_input: The uploaded image (the UI delivers a PIL image), or None.
        unit: Domain label from the UI dropdown; interpolated into the
            system prompt.

    Returns:
        The model's reply, or ``""`` when no image was given or no API key is
        configured (matching ``process_text``).
    """
    endpoint = endpoints(api_key)
    if image_input is None:
        return ""
    # Role sentence and SYS_PROMPT are kept apart by a space so they do not
    # run together into one word.
    system_prompt = f"You are an experienced Analyst in {unit}. " + SYS_PROMPT
    request = "Help me understand what is in this picture and analyze it."
    if endpoint == 'OPENAI':
        client = OpenAI(api_key=api_key)
        base64_image = encode_image_to_base64(image_input)
        response = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {"type": "text", "text": request},
                    {"type": "image_url",
                     "image_url": {
                         "url": f"data:image/jpeg;base64,{base64_image}",
                         "detail": "low"}},
                ]},
            ],
            temperature=0.0,
            max_tokens=1024,
        )
        return response.choices[0].message.content
    if endpoint == 'GOOGLE':
        print(image_input)
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(model_name=MODEL)
        # Separate the system prompt from the request so the end of SYS_PROMPT
        # does not merge into it.
        prompt = f"{system_prompt}\n\n{request}"
        response = model.generate_content([prompt, image_input], request_options={"timeout": 60})
        return response.text
    # No usable API key: return an empty answer rather than None.
    return ""
| |
|
| |
|
def main(text_input="", image_input=None, unit=""):
    """Route a UI submission to the image or text analyser.

    An uploaded image takes precedence: when both an image and text are
    provided, only the image is analysed and the text is ignored.

    Args:
        text_input: Free text from the textbox.
        image_input: Uploaded image, or None.
        unit: Domain label from the dropdown.

    Returns:
        The analysis string produced by ``process_image`` / ``process_text``.

    Raises:
        gr.Error: when neither an image nor non-blank text was provided.
    """
    if image_input is not None:
        return process_image(image_input, unit)
    if text_input and text_input.strip():
        return process_text(text_input, unit)
    # gr.Error must be raised (not merely constructed) for Gradio to show it.
    raise gr.Error("请输入内容或者上传图片")
| |
|
# Sample inputs for gr.Examples. Each row is [image path or None, text], in the
# same order as the Examples inputs ([image_input, text_input]).
EXAMPLES = [[f"./docs/{stem}.jpeg", ""] for stem in ("estate", "pop", "debt")] + [
    [None, "中国央行表示高度关注当前债券市场变化及潜在风险,必要时会进行卖出低风险债券包括国债操作"],
]
| |
|
# Assemble the Gradio UI. Blocks lays components out in the order they are
# created, so the statement order below is the on-screen layout.
# NOTE(review): nesting reconstructed from a copy whose indentation was lost
# (Accordion holds only the description; everything else sits directly in the
# Blocks) — confirm against the intended layout.
with gr.Blocks(theme='shivi/calm_seafoam', css=css, title="知觉demo") as iface:
    # Untitled accordion wrapping the introductory description.
    with gr.Accordion(""):
        gr.Markdown(DESCRIPTION)
    # Domain ("领域") selector; its value reaches main() as `unit` and is
    # interpolated into the system prompt.
    unit = gr.Dropdown(label="领域", value='财经', elem_id="units",
                       choices=["财经", "法律", "政治", "体育", "医疗", \
                                "SEO", "评估", "科技", "交通", "行情"])
    # Model reply ("分析" = analysis), rendered as Markdown.
    with gr.Row():
        output_box = gr.Markdown(label="分析")
    # User inputs: an image (handed to handlers as a PIL image) and free text.
    with gr.Row():
        image_input = gr.Image(type="pil", label="上传图片")
        text_input = gr.Textbox(label="输入")
    with gr.Row():
        submit_btn = gr.Button("🚀 确认")
        # Resets the output, image and text components.
        clear_btn = gr.ClearButton([output_box,image_input,text_input], value="🗑️ 清空")

    # inputs map positionally onto main(text_input, image_input, unit); its
    # return value fills output_box.
    submit_btn.click(main, inputs=[text_input, image_input, unit], outputs=output_box)
    # Each EXAMPLES row is [image path or None, text], matching this order.
    gr.Examples(examples=EXAMPLES, inputs=[image_input, text_input])
    # Footer line naming the configured model.
    gr.Markdown(LICENSE)
| | |
| | |
| |
|
if __name__ == "__main__":
    # Start the server only when executed as a script, so importing this
    # module (tests, reusing `iface`) does not launch the app as a side effect.
    iface.queue().launch(show_api=False)