abhlash commited on
Commit
924bd16
·
1 Parent(s): 0df4ea2

updated model

Browse files
Files changed (1) hide show
  1. app.py +37 -10
app.py CHANGED
@@ -1,25 +1,52 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import os
4
  from dotenv import load_dotenv
5
  import logging
6
  import sys # Ensure sys is imported
7
- from huggingface_hub import login
8
 
9
  # Load environment variables
10
  load_dotenv()
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', stream=sys.stdout)
12
 
13
  # Authenticate with Hugging Face
14
- hf_token = os.environ.get("HUGGING_FACE_TOKEN")
15
- if not hf_token:
16
- raise ValueError("HUGGING_FACE_TOKEN not found in environment variables")
17
- login(token=hf_token)
18
-
19
- # Load the Llama-3.1-8B model and tokenizer
20
  model_name = "meta-llama/Llama-3.1-8B"
21
- tokenizer = AutoTokenizer.from_pretrained(model_name)
22
- model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Function to generate a formatted email
25
  def generate_email(recipient_name, recipient_email, industry, recipient_role, details):
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
3
  import os
4
  from dotenv import load_dotenv
5
  import logging
6
  import sys # Ensure sys is imported
7
+ from huggingface_hub import login, HfApi
8
 
9
  # Load environment variables
10
  load_dotenv()
11
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', stream=sys.stdout)
12
 
13
  # Authenticate with Hugging Face
14
+ hf_token = os.environ.get("HUGGINGFACE_TOKEN")
 
 
 
 
 
15
  model_name = "meta-llama/Llama-3.1-8B"
16
+ fallback_model = "facebook/opt-350m"
17
+
18
+ if hf_token:
19
+ try:
20
+ login(token=hf_token)
21
+ api = HfApi()
22
+ api.whoami()
23
+ logging.info("Successfully logged in to Hugging Face")
24
+ except Exception as e:
25
+ logging.error(f"Error authenticating with Hugging Face: {str(e)}")
26
+ logging.warning("Proceeding without authentication. Will use fallback model.")
27
+ model_name = fallback_model
28
+ else:
29
+ logging.warning("HUGGINGFACE_TOKEN not found in environment variables. Proceeding without authentication.")
30
+ model_name = fallback_model
31
+
32
+ # Load the model and tokenizer
33
+ try:
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
+
36
+ # Custom configuration to handle the RoPE scaling issue
37
+ if model_name == "meta-llama/Llama-3.1-8B":
38
+ config = LlamaConfig.from_pretrained(model_name)
39
+ config.rope_scaling = {"type": "linear", "factor": 8.0} # Adjust as needed
40
+ model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
41
+ else:
42
+ model = AutoModelForCausalLM.from_pretrained(model_name)
43
+
44
+ logging.info(f"Successfully loaded {model_name}")
45
+ except Exception as e:
46
+ logging.error(f"Error loading {model_name}: {str(e)}")
47
+ logging.info(f"Falling back to {fallback_model}")
48
+ tokenizer = AutoTokenizer.from_pretrained(fallback_model)
49
+ model = AutoModelForCausalLM.from_pretrained(fallback_model)
50
 
51
  # Function to generate a formatted email
52
  def generate_email(recipient_name, recipient_email, industry, recipient_role, details):