Sefat33 commited on
Commit
2f288b6
·
verified ·
1 Parent(s): 2bab570

Update guide_model.py

Browse files
Files changed (1) hide show
  1. guide_model.py +11 -29
guide_model.py CHANGED
@@ -1,47 +1,29 @@
1
import os

# Hugging Face libraries read these variables at import time, so they have to
# be set before `transformers` is imported further down.
cache_dir = "/tmp/hf_cache"
_CACHE_ENV_VARS = ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_METRICS_CACHE")
for _var in _CACHE_ENV_VARS:
    os.environ[_var] = cache_dir

# Create the cache directory up front; if /tmp is not writable, fall back to a
# local ./hf_cache directory and repoint every environment variable at it.
try:
    os.makedirs(cache_dir, exist_ok=True)
except PermissionError:
    cache_dir = "./hf_cache"
    os.makedirs(cache_dir, exist_ok=True)
    for _var in _CACHE_ENV_VARS:
        os.environ[_var] = cache_dir
20
 
 
21
from transformers import pipeline

# Build the GPT-2 text-generation pipeline a single time at import so every
# request reuses the already-loaded model instead of reloading it.
try:
    generator = pipeline("text-generation", model="gpt2", cache_dir=cache_dir)
except Exception as e:
    print("⚠️ Failed to load model:", e)
    generator = None
33
 
34
-
35
def generate_description(country_name):
    """Return a short GPT-2-generated description of *country_name*.

    Always returns a str: either generated text, or a ⚠️-prefixed warning
    when the model is unavailable, the pipeline output is unusable, or
    generation raises.
    """
    if not generator:
        return "⚠️ Model is not available. Please check the server logs."

    try:
        outputs = generator(f"Tell me about {country_name}.", max_length=100, do_sample=True)
        # Only trust a non-empty list response from the pipeline.
        if not (outputs and isinstance(outputs, list)):
            return "⚠️ Failed to generate text."
        return outputs[0]["generated_text"].strip()
    except Exception as e:
        return f"⚠️ Error generating description: {str(e)}"
 
1
import os

# Point every Hugging Face cache location at /tmp BEFORE transformers is
# imported below, since the library reads these variables at import time.
_CACHE_DIR = "/tmp/hf_cache"
for _var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "HF_METRICS_CACHE"):
    os.environ[_var] = _CACHE_DIR
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
import transformers
from transformers import pipeline

# The cache directory must exist before the first model download writes to it.
os.makedirs("/tmp/hf_cache", exist_ok=True)

# Load the GPT-2 text-generation pipeline once at import time so every request
# reuses the same model instance. On failure, `generator` stays None and
# generate_description() reports the outage instead of crashing the app.
try:
    # `cache_dir` is not a parameter of pipeline() itself; per the transformers
    # docs it must be forwarded to from_pretrained via `model_kwargs`.
    generator = pipeline(
        "text-generation",
        model="gpt2",
        model_kwargs={"cache_dir": "/tmp/hf_cache"},
    )
except Exception as e:
    generator = None
    print("⚠️ Failed to load model:", e)
19
 
 
20
def generate_description(country_name):
    """Return a short GPT-2-generated description of *country_name*.

    Always returns a str: either generated text, or a ⚠️-prefixed warning
    when the model is unavailable, the pipeline output is unusable, or
    generation raises.
    """
    if not generator:
        return "⚠️ Model is not available. Please check the server logs."

    prompt = f"Tell me about {country_name}."
    try:
        result = generator(prompt, max_length=100, do_sample=True)
        # Validate the pipeline response instead of letting result[0] raise
        # and surface as a confusing "list index out of range" error string.
        if result and isinstance(result, list):
            return result[0]["generated_text"].strip()
        return "⚠️ Failed to generate text."
    except Exception as e:
        return f"⚠️ Error generating description: {str(e)}"