# Source: Hugging Face Space "app.py" (commit 1cba00b, verified) by lI7Il.
import gradio as gr
from huggingface_hub import InferenceClient
import evaluate
import psycopg2
from psycopg2.extras import RealDictCursor
from dotenv import load_dotenv
import os
# Load environment variables from a local .env file (no-op if absent).
load_dotenv()
# Connection string for the Neon-hosted Postgres instance that stores the
# QA table queried in get_relevant_answers().
# token_hf = os.getenv('ACCESS_TOKEN')
conn_str = os.getenv('NEON_CONN_STR')
# Evaluation metrics, loaded once at import time and reused by print_score().
rouge = evaluate.load('rouge')
meteor = evaluate.load('meteor')
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken,
):
    """Stream a RAG-augmented chat completion for the Gradio ChatInterface.

    Retrieves context for `message` from the vector store, prepends the
    system prompt and prior history, and yields the growing response text
    as tokens stream in from the Inference API. After streaming finishes,
    generation-quality scores are printed when context was retrieved.

    For more information on `huggingface_hub` Inference API support, please
    check the docs:
    https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
    """
    client = InferenceClient(token=hf_token.token)
    retrieved_info = get_relevant_answers(message, hf_token.token)

    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    if retrieved_info:
        # Inline the retrieved context with the question; the system prompt
        # instructs the model how to use (or ignore) it.
        messages.append({
            "role": "user",
            "content": f"Answer the following question using the provided context DO NOT mention that you're referring to any context:\nContext: {retrieved_info}\nQuestion: {message}"
        })
    else:
        messages.append({
            "role": "user",
            "content": message,
        })

    response = ""
    # BUGFIX: the loop variable was named `message`, shadowing the user's
    # question parameter after the first streamed chunk. Use `chunk` instead.
    for chunk in client.chat_completion(
        messages,
        model='openai/gpt-oss-20b',
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        token = ""
        # Guard: some stream events carry no choices or an empty delta.
        if len(choices) and choices[0].delta.content:
            token = choices[0].delta.content
        response += token
        yield response

    if retrieved_info:
        print_score(response, retrieved_info)
# Print the highest score across references
def print_score(prediction, references):
print(f"{prediction}\n{references}\n\n=== Generation Scores ===")
result = rouge.compute(predictions=[prediction], references=[references])
print("ROUGELSum:", round(result["rougeLsum"], 2))
result = meteor.compute(predictions=[prediction], references=[references])
print("METEOR:", round(result["meteor"], 2))
print(f"{'_'*50}\n")
# Multilangual retrieval
def get_relevant_answers(prompt, token):
embedding_clinet = InferenceClient(
provider="hf-inference",
api_key=token,
)
with psycopg2.connect(conn_str) as conn:
with conn.cursor(cursor_factory=RealDictCursor) as cur:
query_embedding = embedding_clinet.feature_extraction(prompt, model="google/embeddinggemma-300m").tolist()
cur.execute("""
SELECT a, q, q_embeddings_multilang <=> %s::vector AS distance
FROM qa
WHERE q_embeddings_multilang <=> %s::vector < 0.4
ORDER BY distance
LIMIT 3;
""",
(query_embedding, query_embedding))
rows = cur.fetchall()
for row in rows:
print(f"\n{row['a']}\n{row['q']}\n{row['distance']}\n---")
relevant_answers = [x['a'] for x in rows]
print(f"Relevant Answers: {relevant_answers}", end='\n---\n')
return relevant_answers
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI. The widgets in `additional_inputs` are passed positionally to
# `respond` after (message, history): system_message, max_tokens,
# temperature, top_p. `hf_token` is injected by Gradio's OAuth support.
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        # Default system prompt; editable by the user at runtime.
        gr.Textbox(
            value="""You are an assistant for question-answering about a planet named Zephyra.
- You have access to external resources which are added along with the user question.
- Use the resources provided along with the user question to answer the question.
- Think hard about whether you can answer the given question using the context that you have or not.
- If the context provided is irrelevant to the user question, answer as if no context is provided.
- If you don't know the answer, just say that you don't know.
- Answer with the same language as the question.
- Keep the answer concise.""",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
# Top-level layout: sidebar with the HF OAuth login button (required so
# `respond` receives a valid gr.OAuthToken), plus the chat interface.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()