sushmapiraka committed on
Commit
5cd6d55
·
verified ·
1 Parent(s): 59a4324

Upload 3 files

Browse files
Files changed (3) hide show
  1. main.py +38 -0
  2. models_api.py +57 -0
  3. requirements.txt +48 -0
main.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from models_api import get_answer, get_hugging_face_answer
3
+
4
+ # Define Streamlit app
5
+ def main():
6
+ st.title("Document Analysis Tool")
7
+
8
+ # File upload
9
+ uploaded_file = st.file_uploader("Upload a text file", type=["txt"])
10
+
11
+ if uploaded_file:
12
+ file_contents = uploaded_file.read().decode("utf-8")
13
+ st.text_area("File Content", file_contents, height=200)
14
+
15
+ # User input question
16
+ question = st.text_input("Enter your question")
17
+
18
+ if st.button("Get Answer"):
19
+ if question:
20
+ # Call API functions based on user choice
21
+ st.write("Response from Llama3 8b:")
22
+ answer_llama3 = get_answer("llama3", file_contents, question)
23
+ # answer_llama3 = get_hugging_face_answer("meta-llama/Meta-Llama-3-8B-Instruct", file_contents, question)
24
+ st.write(answer_llama3)
25
+
26
+ st.write("Response from Mistral 7b:")
27
+ answer_mistral7b = get_answer("mistral", file_contents, question)
28
+ # answer_mistral7b = get_hugging_face_answer("mistralai/Mistral-7B-Instruct-v0.1", file_contents, question)
29
+ st.write(answer_mistral7b)
30
+
31
+ st.write("Response from Gemma 7b:")
32
+ # answer_gemma7b = get_answer("gemma", file_contents, question)
33
+ answer_gemma7b = get_hugging_face_answer("google/gemma-7b-it", file_contents, question)
34
+ st.write(answer_gemma7b)
35
+
36
+ # Run Streamlit app
37
+ if __name__ == "__main__":
38
+ main()
models_api.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+ import os
4
+ from dotenv import load_dotenv
5
+ load_dotenv()
6
+
7
+ def get_answer(model_name, context, question):
8
+ llm_key = os.getenv("llm_key")
9
+ url = os.getenv("main_url")
10
+ # Construct the prompt for the model
11
+ prompt = f"You are a Question Answering Model. Can you help me answer the question: {question} from the context: {context}? Just return the answer only. The document may contain some Arabic text; please translate that to English if needed."
12
+
13
+ # Prepare payload for API request
14
+ payload = {
15
+ "model": model_name,
16
+ "messages": [
17
+ {
18
+ "role": "user",
19
+ "content": prompt
20
+ }
21
+ ],
22
+ "max_tokens": 300,
23
+ "temperature": 0.2
24
+ }
25
+
26
+ headers = {
27
+ 'Authorization': f'Bearer {llm_key}',
28
+ 'Content-Type': 'application/json'
29
+ }
30
+
31
+ # Convert payload to JSON string
32
+ json_payload = json.dumps(payload)
33
+
34
+ try:
35
+ # Send POST request to the API
36
+ response = requests.post(url, headers=headers, data=json_payload)
37
+
38
+ # Check if request was successful
39
+ if response.status_code == 200:
40
+ response_data = response.json() # Parse response JSON
41
+ answer = response_data['choices'][0]['message']['content'] # Extract model's answer from response
42
+ return answer
43
+ else:
44
+ print(f"Request failed with status code: {response.status_code}")
45
+ return None
46
+
47
+ except requests.exceptions.RequestException as e:
48
+ print(f"Error occurred: {e}")
49
+ return None
50
+
51
+ from huggingface_hub import InferenceClient
52
+
53
+ def get_hugging_face_answer(model_name, context, question):
54
+ client = InferenceClient(model_name, token=os.getenv("HF_TOKEN"))
55
+ prompt = f"You are a Question Answering Model. Can you help me answer the question: {question} from the context: {context}? Just return the answer only. The document may contain some Arabic text; please translate that to English if needed."
56
+ output = client.text_generation(prompt , max_new_tokens = 200, stream=True, temperature=0.1)
57
+ return output
requirements.txt ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.3.0
2
+ attrs==23.2.0
3
+ blinker==1.8.1
4
+ cachetools==5.3.3
5
+ certifi==2024.2.2
6
+ charset-normalizer==3.3.2
7
+ click==8.1.7
8
+ colorama==0.4.6
9
+ filelock==3.14.0
10
+ fsspec==2024.3.1
11
+ gitdb==4.0.11
12
+ GitPython==3.1.43
13
+ huggingface-hub==0.23.0
14
+ idna==3.7
15
+ Jinja2==3.1.4
16
+ jsonschema==4.22.0
17
+ jsonschema-specifications==2023.12.1
18
+ markdown-it-py==3.0.0
19
+ MarkupSafe==2.1.5
20
+ mdurl==0.1.2
21
+ numpy==1.26.4
22
+ packaging==24.0
23
+ pandas==2.2.2
24
+ pillow==10.3.0
25
+ protobuf==4.25.3
26
+ pyarrow==16.0.0
27
+ pydeck==0.9.0
28
+ Pygments==2.18.0
29
+ python-dateutil==2.9.0.post0
30
+ python-dotenv==1.0.1
31
+ pytz==2024.1
32
+ PyYAML==6.0.1
33
+ referencing==0.35.1
34
+ requests==2.31.0
35
+ rich==13.7.1
36
+ rpds-py==0.18.0
37
+ six==1.16.0
38
+ smmap==5.0.1
39
+ streamlit==1.34.0
40
+ tenacity==8.2.3
41
+ toml==0.10.2
42
+ toolz==0.12.1
43
+ tornado==6.4
44
+ tqdm==4.66.4
45
+ typing_extensions==4.11.0
46
+ tzdata==2024.1
47
+ urllib3==2.2.1
48
+ watchdog==4.0.0