ESPnet
Afar
code
pra1223 commited on
Commit
317ffbb
·
verified ·
1 Parent(s): 9522d1a

Create code(4)

Browse files
Files changed (1) hide show
  1. code(4) +127 -0
code(4) ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import re
5
+
6
+ # --- Backend Functions ---
7
+
8
+ def initialize_gemini_api(api_key):
9
+ """Initializes the Gemini model and tokenizer."""
10
+ try:
11
+ # Using Auto Classes is generally recommended for loading from Hugging Face
12
+ tokenizer = AutoTokenizer.from_pretrained("google/gemini-1.5-pro-001", token=api_key) #check if model has a tokenizer and version number.
13
+ model = AutoModelForCausalLM.from_pretrained("google/gemini-1.5-pro-001", token=api_key, device_map="auto", torch_dtype=torch.bfloat16) #Added model device and dtype.
14
+
15
+ return model, tokenizer
16
+ except Exception as e:
17
+ st.error(f"Error initializing model: {e}")
18
+ return None, None
19
+
20
+
21
+ def preprocess_input(user_input, input_type):
22
+ """Preprocesses the input based on the input type."""
23
+
24
+ prompt_templates = {
25
+ "recipe_suggestion": "I have the following ingredients: {}. Suggest a recipe, and the recipe must include the ingredients I provided. Provide steps",
26
+ "promotion_idea": "Suggest a promotion to increase customer engagement based on these goals/themes: {}.",
27
+ "waste_reduction_tip": "Suggest strategies, including numbered steps, to minimize food waste based on this context/these ingredients: {}.",
28
+ "event_planning": "I want to plan an event. Here's the description/goals/requirements: {}. Give detailed, step-by-step instructions and important considerations.",
29
+ }
30
+
31
+ prompt = prompt_templates.get(input_type)
32
+ if prompt:
33
+ return prompt.format(user_input)
34
+ else:
35
+ return "Invalid input type." # Should ideally never happen due to Streamlit UI controls.
36
+
37
+ def generate_suggestion(model, tokenizer, processed_input):
38
+ """Generates text using the Gemini model."""
39
+
40
+ try:
41
+ input_ids = tokenizer(processed_input, return_tensors="pt").to(model.device) # Make sure tensors are on same device
42
+ outputs = model.generate(**input_ids, max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.95, do_sample=True) # Added important params for generation quality
43
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
+ return generated_text
45
+ except Exception as e:
46
+ st.error(f"Error during generation: {e}")
47
+ return "An error occurred during suggestion generation."
48
+
49
+
50
+ def postprocess_output(raw_response, input_type):
51
+ """Postprocesses the generated text."""
52
+
53
+ # Remove any leading/trailing whitespace
54
+ cleaned_response = raw_response.strip()
55
+ # Further, specific postprocessing according to context
56
+ if input_type == 'recipe_suggestion':
57
+ try:
58
+ pass # Can add custom filtering
59
+ except:
60
+ pass
61
+
62
+ elif input_type == 'promotion_idea':
63
+ try:
64
+ pass #Can add custom regex and filters
65
+ except:
66
+ pass
67
+ elif input_type == "waste_reduction_tip" or input_type == 'event_planning':
68
+ try:
69
+ # Check to ensure instructions and steps in final output.
70
+ pass
71
+ except:
72
+ pass
73
+
74
+ # Basic example: Split into sentences for better readability (can be improved)
75
+ sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', cleaned_response)
76
+ formatted_response = "\n\n".join(sentences)
77
+ return formatted_response
78
+ def get_ai_suggestion(user_input, input_type, api_key):
79
+ model, tokenizer = initialize_gemini_api(api_key)
80
+ if model is None or tokenizer is None:
81
+ return "Failed to initialize the model. Check your API key."
82
+ processed_input = preprocess_input(user_input, input_type)
83
+ raw_response = generate_suggestion(model, tokenizer, processed_input)
84
+ formatted_response = postprocess_output(raw_response, input_type)
85
+ return formatted_response
86
+
87
+
88
+
89
+ # --- Streamlit Frontend ---
90
+
91
+ st.set_page_config(page_title="AI Restaurant Assistant", layout="wide") #Set page config
92
+ st.sidebar.title("AI Restaurant Assistant")
93
+
94
+ # --- API KEY HANDLING ---
95
+
96
+ # Use st.session_state to persist the API key *only for the session*
97
+ if 'api_key' not in st.session_state:
98
+ st.session_state.api_key = ''
99
+ # IMPORTANT SECURITY NOTE: This method is suitable for demonstration/local development.
100
+ # For a production deployment, you MUST use a more secure method of storing the API key,
101
+ # such as environment variables and NEVER hardcode it or commit it to version control.
102
+ api_key_input = st.sidebar.text_input("Enter your Hugging Face API key:", type="password", value=st.session_state.api_key)
103
+
104
+ if api_key_input:
105
+ st.session_state.api_key = api_key_input #Value is automatically cached and input bar has api_key once entered.
106
+
107
+ if not st.session_state.api_key:
108
+ st.sidebar.warning("Please enter your Hugging Face API key to use the application.")
109
+ st.stop() # Stop execution if no API key
110
+
111
+ # --- Input Selection ---
112
+ input_type = st.sidebar.selectbox("What kind of suggestion do you need?",
113
+ ["recipe_suggestion", "promotion_idea", "waste_reduction_tip", "event_planning"])
114
+
115
+ # --- Main Area ---
116
+ st.title("Get AI-Powered Suggestions")
117
+ st.write("This tool leverages the power of the Gemini 1.5 Pro model to assist with various restaurant management tasks.") # Introduction and description
118
+ user_input = st.text_area("Enter your input here:", height=150, key="user_input") #Key is added
119
+
120
+ if st.button("Generate Suggestion"):
121
+ if user_input:
122
+ with st.spinner("Generating suggestion..."):
123
+ suggestion = get_ai_suggestion(user_input, input_type, st.session_state.api_key)
124
+ st.markdown("### AI Suggestion:", unsafe_allow_html=True) #Style output and enhance it visually.
125
+ st.write(suggestion) #Can upgrade output design by having boxes etc.
126
+ else:
127
+ st.warning("Please enter some input.")