File size: 5,213 Bytes
720ed94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944f540
720ed94
 
 
 
 
 
 
 
 
944f540
720ed94
29ec62a
 
720ed94
ec59be7
720ed94
 
 
 
 
 
 
 
 
 
 
29ec62a
720ed94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cf6d01
720ed94
2fbc411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
720ed94
 
 
 
 
 
 
 
 
 
 
2fbc411
 
 
 
 
 
 
3cf6d01
720ed94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0bcd0c
bdf89c3
 
 
 
 
 
 
 
b0bcd0c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import getpass
import logging
import os
import re

# from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from groq import Groq
from langchain import hub
from langchain.chat_models import init_chat_model
from langchain_cohere import CohereEmbeddings
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from pydantic.main import BaseModel
from typing_extensions import List, TypedDict

# Chat model (served by Groq) and the embedding model used to index and search
# the knowledge base. Both API keys come from environment variables; indexing
# os.environ raises KeyError at import time if either one is missing.
llm = init_chat_model(
    "qwen-qwq-32b",
    model_provider="groq",
    api_key=os.environ["GROQ_API_KEY"],
)

embeddings = CohereEmbeddings(
    cohere_api_key=os.environ["COHERE"],
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings",
)

vector_store = InMemoryVectorStore(embedding=embeddings)


def _read_text(path):
    """Return the whole contents of the text file at *path*.

    The file handle is closed deterministically via ``with`` (the previous
    ``open(...).read()`` pattern leaked handles until garbage collection), and
    UTF-8 is decoded explicitly instead of relying on the platform default.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return fh.read()


# Raw text sources. They are not used by the indexing pipeline below (which
# loads comb.md), but are still read so a missing file fails startup as before.
data_1 = _read_text(r"data_1.txt")
data_2 = _read_text(r"data_2.txt")
data_3 = _read_text(r"data_3.txt")
data_4 = _read_text(r"data_4.txt")

comb = _read_text(r"comb.txt")

# Load the Markdown knowledge base and split it into overlapping chunks.
md_loader = UnstructuredMarkdownLoader("comb.md")

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
all_splits = text_splitter.split_documents(md_loader.load())

# Re-wrap each chunk as a Document (content + metadata) and embed it into the store.
docs = [Document(page_content=split.page_content, metadata=split.metadata) for split in all_splits]
_ = vector_store.add_documents(documents=docs)


# NOTE(review): `prompt` is not used by generate(), which builds its own
# messages; this still performs a LangChain Hub network fetch at import time.
# Kept so the module-level name and behavior stay the same — consider removing.
prompt = hub.pull("rlm/rag-prompt")

# Custom prompt used in place of the hub prompt: a system message describing
# the bot's persona and rules, plus a human-message template that generate()
# fills with the retrieved context and the user's question.
_ROLE_RULES = (
    "Provide accurate and concise answers based on the provided context",
    "Be friendly but professional in tone",
    "If you don't know the answer, simply say \"I don't have information about that\"",
    "Keep responses brief and to the point",
    "Focus on providing factual information from the context",
    'Never mention "the provided context" or similar phrases in your responses',
    "Never explain why you don't know something - just state that you don't know",
    "Be direct and avoid unnecessary explanations",
)

system_message = "\n".join(
    [
        "You are a helpful and professional FAQ chatbot for the MLSC Coherence 25 Hackathon. Your role is to:",
        *(f"{number}. {rule}" for number, rule in enumerate(_ROLE_RULES, start=1)),
    ]
)

human_message_template = "\n\n".join(
    (
        "Context: {context}",
        "Question: {question}",
        "Please provide a clear and concise answer based on the context above.",
    )
)

class State(TypedDict):
    """State shared by the LangGraph nodes ``retrieve`` and ``generate``."""

    question: str  # user question, supplied when the graph is invoked
    context: List[Document]  # chunks written by retrieve(), read by generate()
    answer: str  # model reply written by generate()

def retrieve(state: State):
    """Look up the indexed chunks most similar to ``state["question"]``.

    Returns a partial state update that stores the hits under ``"context"``.
    """
    hits = vector_store.similarity_search(state["question"])
    return dict(context=hits)

def generate(state: State):
    """Ask the LLM to answer ``state["question"]`` using ``state["context"]``.

    The retrieved chunks are joined with blank lines and substituted, together
    with the question, into ``human_message_template``; ``system_message`` is
    sent first to set the bot's behavior. Returns a partial state update with
    the model's raw reply under ``"answer"``.
    """
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = [
        SystemMessage(content=system_message),
        HumanMessage(
            content=human_message_template.format(
                context=docs_content,
                question=state["question"],
            )
        ),
    ]
    # The full prompt contains user input and retrieved text: emit it only at
    # DEBUG level rather than unconditionally printing it on every request.
    logging.getLogger(__name__).debug("LLM prompt messages: %s", messages)
    response = llm.invoke(messages)
    return {"answer": response.content}

# Pipeline: START -> retrieve -> generate, compiled into a runnable graph.
# add_sequence registers both nodes (named after the functions) and the edge
# between them on the builder itself.
graph_builder = StateGraph(State)
graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

app = FastAPI()

origins = ["*"]

# CORS policy for browser clients.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive combination — confirm credentialed cross-origin requests are needed.
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "PUT", "DELETE"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

@app.get("/ping")
async def ping():
    """Liveness probe: always replies with the string "Pong!"."""
    reply = "Pong!"
    return reply

class Query(BaseModel):
    """Request body accepted by the /chat endpoints."""

    question: str  # user question, passed to the graph as state["question"]

# The model reply may contain <think>...</think> reasoning sections; they are
# removed before the answer is returned. Compiled once instead of per request.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def _answer_question(question: str) -> str:
    """Run the RAG graph for *question* and return the user-facing answer."""
    answer = graph.invoke({"question": question})["answer"]
    # Drop reasoning sections and the surrounding whitespace they leave behind.
    return _THINK_BLOCK_RE.sub("", answer).strip()


# Both handlers are plain ``def``: graph.invoke makes blocking network calls,
# and FastAPI runs sync path operations in a threadpool, so one slow request no
# longer stalls the event loop for every other client. ``name="chat"`` keeps
# the GET route's name (and generated OpenAPI operation id) unchanged now that
# the two handlers no longer share one function name.
@app.get("/chat", name="chat")
def chat_get(request: Query):
    """GET variant of /chat; the question is read from the JSON request body."""
    return {"response": _answer_question(request.question)}


@app.post("/chat")
def chat(request: Query):
    """Answer ``request.question`` and return it as ``{"response": ...}``."""
    return {"response": _answer_question(request.question)}

@app.get("/")
async def root():
    """Root endpoint returning a fixed greeting payload."""
    return dict(message="Hello World")