# TitleMatcher / api_call.py
# alireza-2003's picture
# Update api_call.py
# a12e4e9 verified
import os
import time
import random
from httpx import HTTPStatusError
from langchain.prompts import ChatPromptTemplate
from dotenv import load_dotenv
from langchain_mistralai import ChatMistralAI
from langchain_core.output_parsers import StrOutputParser
class Chain:
    """Ask a Mistral chat model to rate the similarity of two titles.

    A few-shot prompt containing a ``{titles}`` placeholder is piped through
    ``ChatMistralAI`` and ``StrOutputParser``; the model's textual reply is
    returned to the caller after stripping surrounding whitespace.
    """

    def __init__(self, model="mistral-large-latest"):
        """Read the API key from the environment and build the prompt chain.

        Args:
            model: Model name handed to ``ChatMistralAI``.
        """
        # Load environment variables
        load_dotenv()
        self.mistral_api_key = os.getenv("MISTRAL_API_KEY")
        # Keep the requested model name around; ``self.model`` (below) holds
        # the chat client object, as it always has.
        self.model_name = model

        # Initialize model and parser
        self.model = ChatMistralAI(model=model, api_key=self.mistral_api_key)
        self.parser = StrOutputParser()

        # Define template (kept flush-left so the prompt text sent to the
        # model is unchanged)
        self._template_for_sentiment = """
You are an expert model for measuring the similarity between titles. Given two titles, please provide a score between 0 and 1, where 0 means no similarity and 1 means identical similarity. Here are some examples:
Examples: sentence 1: "Bowl of Strawberry Ice Cream"
sentence 2: "Strawberry Ice Cream Cone"
Similarity score: 0.9
sentence 1: "Digital Camera 1080p"
sentence 2: "High-definition Camera"
Similarity score: 0.8
sentence 1: "Electric Toothpaste"
sentence 2: "Manual Toothbrush"
Similarity score: 0.2
Now, measure the similarity between the following pairs of titles(just response one decimal number between 0 and 1):
{titles}
"""
        self._initialize_chains()

    def _initialize_chains(self):
        """Initialize the prompt chains"""
        self.prompt_single = ChatPromptTemplate.from_template(self._template_for_sentiment)
        self.chain_single = self.prompt_single | self.model | self.parser

    def _classify_with_backoff(self, titles, max_retries=10):
        """Compare two titles, retrying with exponential backoff on HTTP 429.

        Args:
            titles: Text substituted for ``{titles}`` in the prompt.
            max_retries: Maximum number of attempts made against the chain.

        Returns:
            The model's reply with surrounding whitespace stripped, or the
            string ``"Error"`` when every attempt was rate-limited (or when
            ``max_retries`` allows no attempt at all).

        Raises:
            HTTPStatusError: For any HTTP error status other than 429.
        """
        retries = 0
        while retries < max_retries:
            try:
                response = self.chain_single.invoke({"titles": titles})
            except HTTPStatusError as e:
                if e.response.status_code != 429:
                    raise
                retries += 1
                if retries >= max_retries:
                    # No attempt is left: sleeping now would only delay the
                    # "Error" result (by roughly 2 ** (max_retries - 1) s).
                    break
                # Exponential backoff with up to one second of jitter.
                wait_time = 2 ** (retries - 1) + random.uniform(0, 1)
                print(f"Rate limit exceeded. Retrying in {wait_time:.2f} seconds...")
                time.sleep(wait_time)
            else:
                return response.strip()
        return "Error"

    def classify(self, title_0, title_1, max_retries=10):
        """Return the model's similarity reply for ``title_0`` and ``title_1``.

        The two titles are joined as ``"<title_0> and <title_1>"`` and sent
        through :meth:`_classify_with_backoff`.

        Args:
            title_0: First title.
            title_1: Second title.
            max_retries: Maximum number of attempts when rate-limited.

        Returns:
            The stripped model reply, or ``"Error"`` if all attempts were
            rate-limited.
        """
        return self._classify_with_backoff(f"{title_0} and {title_1}", max_retries)
if __name__ == "__main__":
    # Manual smoke test: score one same-language pair and one cross-language
    # pair, printing the model's reply for each.
    matcher = Chain()
    sample_pairs = (
        ("بستنی چوبی ", "بستنی خنک"),
        ("بستنی چوبی ", "Ice cream"),
    )
    for first_title, second_title in sample_pairs:
        print(matcher.classify(first_title, second_title))