Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from datasets import load_dataset, Dataset | |
| from llama_index.core import PromptTemplate | |
| from llama_index.core.prompts import ChatMessage | |
| from llama_index.llms.openai import OpenAI | |
| from pydantic import BaseModel, Field | |
| import asyncio | |
| import numpy as np | |
| import pandas as pd | |
| from chromadb import Client | |
| from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction | |
| import structlog | |
logger = structlog.get_logger()

# Embedding model, loaded once at import time. It is passed as the
# ``embedding_function`` of the ``reviews`` Chroma collection below.
logger.info('Loading embedding model')
embed_fn = SentenceTransformerEmbeddingFunction(model_name='BAAI/bge-small-en-v1.5')
def load_train_data_and_vectorstore(samples_per_class=50, seed=1234):
    """Load a class-balanced training sample and index it in a Chroma collection.

    Draws ``samples_per_class`` rows per ``label`` from the ``train`` split of
    ``SetFit/amazon_reviews_multi_en`` and adds their texts to a Chroma
    collection named ``reviews``. Each document's ``label`` is stored under
    the ``rating`` metadata key.

    Args:
        samples_per_class: Number of rows sampled for each label (default 50).
        seed: ``random_state`` for the per-label pandas sample (default 1234).

    Returns:
        Tuple ``(train, reviews)``: the sampled ``datasets.Dataset`` and the
        Chroma collection holding its texts.
    """
    logger.info("Loading dataset")
    ds = load_dataset('SetFit/amazon_reviews_multi_en')
    train_df = (
        ds['train']
        .to_pandas()
        .groupby('label')
        .sample(samples_per_class, random_state=seed)
        .reset_index(drop=True)
    )
    train = Dataset.from_pandas(train_df)

    # ``Client()`` is created without persistence settings; presumably an
    # in-memory store that is rebuilt on every start -- confirm.
    reviews = Client().create_collection(
        name='reviews',
        embedding_function=embed_fn,
        get_or_create=True,
    )
    logger.info("Adding documents to vector store")
    # Materialise columns as plain lists: newer ``datasets`` releases may
    # return a lazy column object from ``train['col']``.
    reviews.add(
        documents=list(train['text']),
        metadatas=[{'rating': label} for label in train['label']],
        ids=list(train['id']),
    )
    return train, reviews
# Built once at import time; the prediction helpers read these module globals.
(train, reviews) = load_train_data_and_vectorstore()
class Rating(BaseModel):
    """Structured-output schema for the LLM: one integer review rating.

    The ``enum`` entry is a JSON-schema hint that advertises the values 0-4
    to the model. It is a schema annotation only and is not enforced when
    the model's output is validated.
    """

    rating: int = Field(
        ...,
        description="Rating of the review",
        # Pydantic v2 deprecates arbitrary extra kwargs on Field; this emits
        # the same ``enum`` key in the generated JSON schema.
        json_schema_extra={"enum": [0, 1, 2, 3, 4]},
    )
# Default model. No api_key is passed here, so the OpenAI client must find
# credentials on its own (presumably the OPENAI_API_KEY env var -- confirm).
llm = OpenAI(model="gpt-4o-mini")
# Wrap the LLM so chat responses are parsed into ``Rating`` objects.
structured_llm = llm.as_structured_llm(output_cls=Rating)
# Zero-shot prompt: only the review itself, no labelled examples.
# The trailing backslashes are line continuations inside the literal, so the
# rendered text starts at "The review" and ends with "Answer: ".
prompt_tmpl_str = """\
The review text is below.
---------------------
{review}
---------------------
Given the review text and not prior knowledge, \
please attempt to predict the score of the review.
Query: What is the rating of this review?
Answer: \
"""
prompt_tmpl = PromptTemplate(template=prompt_tmpl_str)
async def zero_shot_predict(text, llm=None):
    """Predict a rating for ``text`` with the zero-shot prompt.

    Args:
        text: Review text to classify.
        llm: Structured LLM to query. Defaults to the module-level
            ``structured_llm`` so existing callers are unaffected.

    Returns:
        The ``rating`` field of the structured ``Rating`` response.
    """
    if llm is None:
        llm = structured_llm
    messages = [ChatMessage.from_str(prompt_tmpl.format(review=text))]
    response = await llm.achat(messages)
    return response.raw.rating


# Seeded module-level generator: example draws are reproducible across
# restarts, but the stream is shared by all requests.
rng = np.random.Generator(np.random.PCG64(1234))


def random_few_shot_examples_fn(**kwargs):
    """Render a random sample of training reviews as few-shot examples.

    Registered in ``function_mappings`` of ``few_shot_prompt_tmpl``, so it
    receives every keyword given to ``format``. ``n_samples`` sets how many
    examples to include; a missing or falsy value falls back to 5.
    """
    n_samples = int(kwargs.get('n_samples') or 5)
    # Each call draws a fresh permutation from the shared ``rng``.
    sample = train.shuffle(generator=rng)[:n_samples]
    return "\n\n".join(
        f"Text: {text}\nRating: {rating}"
        for text, rating in zip(sample['text'], sample['label'])
    )


few_shot_prompt_tmpl_str = """\
The review text is below.
---------------------
{review}
---------------------
Given the review text and not prior knowledge, \
please attempt to predict the review score of the context. \
Here are several examples of reviews and their ratings:
{random_few_shot_examples}
Query: What is the rating of this review?
Answer: \
"""
few_shot_prompt_tmpl = PromptTemplate(
    few_shot_prompt_tmpl_str,
    function_mappings={"random_few_shot_examples": random_few_shot_examples_fn},
)


async def random_few_shot_predict(text, n_examples=5, llm=None, n_votes=3):
    """Predict a rating by majority vote over randomly-exampled prompts.

    Args:
        text: Review text to classify.
        n_examples: Random training examples included in each prompt.
        llm: Structured LLM to query; defaults to ``structured_llm``.
        n_votes: Number of independent prompts/calls to vote over (>= 1).

    Returns:
        The most common predicted rating. Ties go to the smallest rating,
        because pandas ``mode()`` returns its modes in sorted order.

    Raises:
        ValueError: If ``n_votes`` is less than 1.
    """
    if n_votes < 1:
        raise ValueError("n_votes must be at least 1")
    if llm is None:
        llm = structured_llm
    # Format all prompts up front; each format() draws a new random example set.
    prompts = [
        few_shot_prompt_tmpl.format(review=text, n_samples=n_examples)
        for _ in range(n_votes)
    ]
    results = await asyncio.gather(
        *(llm.achat([ChatMessage.from_str(p)], temperature=0.9) for p in prompts)
    )
    ratings = [r.raw.rating for r in results]
    return pd.Series(ratings).mode()[0]


def dynamic_few_shot_examples_fn(**kwargs):
    """Render the training reviews nearest to ``review`` as few-shot examples.

    Registered in ``function_mappings`` of ``dynamic_few_shot_prompt_tmpl``.
    It queries the ``reviews`` Chroma collection with the ``review`` kwarg and
    returns the top ``n_examples`` hits (default 5) with their stored rating.
    """
    n_examples = int(kwargs.get('n_examples', 5))
    retrievals = reviews.query(
        query_texts=[kwargs['review']],
        n_results=n_examples,
    )
    # Results are batched per query text; only one query was sent.
    documents = retrievals['documents'][0]
    metadatas = retrievals['metadatas'][0]
    return "\n\n".join(
        f"Text: {document}\nRating: {metadata.get('rating')}"
        for document, metadata in zip(documents, metadatas)
    )


dynamic_few_shot_prompt_tmpl_str = """\
The review text is below.
---------------------
{review}
---------------------
Given the review text and not prior knowledge, \
please attempt to predict the review score of the context. \
Here are several examples of reviews and their ratings:
{dynamic_few_shot_examples}
Query: What is the rating of this review?
Answer: \
"""
dynamic_few_shot_prompt_tmpl = PromptTemplate(
    dynamic_few_shot_prompt_tmpl_str,
    function_mappings={"dynamic_few_shot_examples": dynamic_few_shot_examples_fn},
)


async def dynamic_few_shot_predict(text, n_examples=5, llm=None):
    """Predict a rating using the most similar training reviews as examples.

    Args:
        text: Review text to classify.
        n_examples: Number of nearest-neighbour examples to retrieve.
        llm: Structured LLM to query; defaults to ``structured_llm``.

    Returns:
        The ``rating`` field of the structured ``Rating`` response.
    """
    if llm is None:
        llm = structured_llm
    prompt = dynamic_few_shot_prompt_tmpl.format(review=text, n_examples=n_examples)
    response = await llm.achat([ChatMessage.from_str(prompt)])
    return response.raw.rating


def classify(review, num_examples, api_key=None):
    """Classify ``review`` with the zero-shot, random and dynamic strategies.

    Args:
        review: Review text entered in the UI.
        num_examples: Few-shot example count (slider value; cast to int
            because it is used for slicing and as Chroma's ``n_results``).
        api_key: OpenAI API key from the UI. When empty or None, the
            module-level ``structured_llm`` is used instead.

    Returns:
        Tuple ``(zero_shot, random_few_shot, dynamic_few_shot)`` of ratings.
    """
    n_examples = int(num_examples)
    # Per-request client so the key supplied in the UI is the one used.
    llm = (
        OpenAI(model="gpt-4o-mini", api_key=api_key).as_structured_llm(Rating)
        if api_key
        else structured_llm
    )

    async def _run_all():
        # The strategies are independent: run them concurrently on one loop.
        return await asyncio.gather(
            zero_shot_predict(review, llm=llm),
            random_few_shot_predict(review, n_examples, llm=llm),
            dynamic_few_shot_predict(review, n_examples, llm=llm),
        )

    zero_shot, random_few_shot, dynamic_few_shot = asyncio.run(_run_all())
    return zero_shot, random_few_shot, dynamic_few_shot
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # type='password' masks the secret in the browser.
            api_key = gr.Textbox(label='Openai API Key', type='password')
            n_examples = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label='Number of examples to retrieve',
                interactive=True,
            )
            review = gr.Textbox(label='Review', interactive=True)
            submit = gr.Button(value='Submit')
        with gr.Column():
            zero_shot_label = gr.Textbox(label='Zero shot', interactive=False)
            random_few_shot_label = gr.Textbox(label='Random few shot', interactive=False)
            dynamic_few_shot_label = gr.Textbox(label='Dynamic few shot', interactive=False)
    # Input order must match classify(review, num_examples, api_key); the
    # API-key box has to be wired in or the user's key never reaches classify.
    submit.click(
        classify,
        inputs=[review, n_examples, api_key],
        outputs=[zero_shot_label, random_few_shot_label, dynamic_few_shot_label],
    )

demo.queue().launch()