Mohansai2004 committed on
Commit
04da277
·
1 Parent(s): e7dde2b

feat: switch to DistilGPT-2 text generation

Browse files

- Replace image generation with text generation
- Add DistilGPT-2 model for better CPU performance
- Add temperature and length controls
- Update documentation and requirements
- Optimize for CPU usage
- Add copy text functionality

Files changed (3) hide show
  1. README.md +15 -15
  2. app.py +110 -51
  3. requirements.txt +3 -2
README.md CHANGED
@@ -1,25 +1,25 @@
1
  ---
2
- title: AI Text Generator
3
- emoji: 💬
4
- colorFrom: blue
5
- colorTo: green
6
  sdk: streamlit
7
  sdk_version: 1.41.1
8
  app_file: app.py
9
  pinned: false
10
- short_description: Fast text generation using DistilGPT-2
11
  ---
12
 
13
- # AI Text Generator
14
- Quick text generation using DistilGPT-2
15
 
16
  ## Features
17
- - Fast text generation
18
- - CPU-optimized performance
19
- - Adjustable creativity settings
20
- - Memory efficient (< 2GB)
21
 
22
- ## Usage Tips
23
- - Clear, specific prompts work best
24
- - Adjust temperature for different styles
25
- - Experiment with prompt formats
 
1
  ---
2
+ title: Health Diet Planner
3
+ emoji: 🏥
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.41.1
8
  app_file: app.py
9
  pinned: false
10
+ short_description: AI-powered health diet recommendations
11
  ---
12
 
13
+ # Health Diet Planner
14
+ AI-assisted diet planning based on health conditions
15
 
16
  ## Features
17
+ - Personalized diet recommendations
18
+ - Health condition specific advice
19
+ - Dietary restriction support
20
+ - Downloadable diet plans
21
 
22
+ ## Important Notice
23
+ - For reference purposes only
24
+ - Consult healthcare professionals
25
+ - Not a substitute for medical advice
app.py CHANGED
@@ -2,72 +2,131 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import gc
 
5
 
6
  @st.cache_resource
7
  def load_model():
8
- # Load DistilGPT-2
9
- model_id = "distilgpt2"
10
- tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_id,
13
  torch_dtype=torch.float32,
14
- low_cpu_mem_usage=True
15
- ).to("cpu")
 
 
16
 
17
- # Set threading and memory optimizations
18
- torch.set_num_threads(4)
19
- gc.collect()
20
 
 
21
  return model, tokenizer
22
 
23
st.title("💬 AI Text Generator")
st.write("Generate creative text using DistilGPT-2")

# Initialize model (cached across reruns by @st.cache_resource).
model, tokenizer = load_model()

# User input
prompt = st.text_area(
    "Enter your prompt:",
    "Once upon a time in a digital world,"
)

# Generation settings
with st.sidebar:
    max_length = st.slider("Max Length", 50, 200, 100)
    temperature = st.slider("Temperature", 0.1, 1.0, 0.7)

if st.button("Generate Text"):
    with st.spinner("Generating text..."):
        try:
            # Tokenize and generate
            inputs = tokenizer(prompt, return_tensors="pt")

            with torch.inference_mode():
                outputs = model.generate(
                    inputs["input_ids"],
                    max_length=max_length,
                    temperature=temperature,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.eos_token_id,
                    do_sample=True,
                )

            # Decode and display
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.write("### Generated Text:")
            # BUG FIX: the previous "📋 Copy Text" st.button used
            # on_click=lambda: st.write(generated_text); clicking it
            # triggered a rerun in which this branch (and the variable)
            # no longer existed, so nothing was ever shown or copied.
            # st.code renders the text with a built-in copy control.
            st.code(generated_text, language=None)

        except Exception as e:
            st.error(f"Error: {str(e)}")

st.markdown("""
### Tips for better results:
- Start with clear, well-structured prompts
- Adjust temperature for creativity (higher) or consistency (lower)
- Try different prompt styles for different outputs
""")
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  import gc
5
+ from functools import lru_cache
6
 
7
@st.cache_resource
def load_model():
    """Load BioGPT-Large and its tokenizer, configured for CPU-only use.

    Cached by Streamlit so the heavyweight download, load, and compile
    run once per process.

    Returns:
        tuple: (model, tokenizer) in eval mode on CPU.
    """
    # CPU-only setup: widen the thread pool and drop autograd bookkeeping.
    torch.set_num_threads(6)
    torch.set_grad_enabled(False)

    checkpoint = "microsoft/BioGPT-Large"

    # Tokenizer: left padding, inputs capped at 512 tokens.
    tok = AutoTokenizer.from_pretrained(
        checkpoint,
        model_max_length=512,
        padding_side='left',
    )

    net = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        device_map='cpu',
        max_memory={'cpu': '16GB'},
    )

    net.eval()
    # NOTE(review): torch.compile adds a long first-call warm-up on a
    # CPU-only Space — confirm it actually pays off here.
    net = torch.compile(net)

    gc.collect()
    return net, tok
37
 
38
@lru_cache(maxsize=32)  # Memoize recent plans; keyed on all three string args.
def generate_diet_plan(health_condition: str, dietary_restrictions: str, cache_key: str) -> str:
    """Generate a diet-plan text for a health condition with BioGPT.

    Args:
        health_condition: Condition to plan for (e.g. "Diabetes").
        dietary_restrictions: Free-text restrictions (e.g. "Vegetarian").
        cache_key: Kept for interface compatibility; lru_cache already
            keys on the first two arguments, so it adds no new keying.

    Returns:
        Decoded model output (includes the prompt text, since the full
        sequence is decoded).
    """
    model, tokenizer = load_model()

    prompt = f"""
    Provide a concise diet plan for {health_condition} with {dietary_restrictions} restrictions.
    Key points:
    1. Recommended foods
    2. Foods to avoid
    3. Meal schedule
    """

    try:
        with torch.inference_mode():
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=256,  # Reduced for faster processing
                truncation=True,
                padding=True
            )

            outputs = model.generate(
                inputs["input_ids"],
                # BUG FIX: generate() was called without the attention mask
                # produced by the tokenizer; with padding enabled that lets
                # the model attend over pad positions. Pass it explicitly.
                attention_mask=inputs["attention_mask"],
                max_length=512,  # Reduced length
                temperature=0.7,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.2
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    finally:
        # Clean up on both success and failure (exceptions still propagate
        # to the caller, which reports them in the UI).
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
76
 
77
# Streamlit interface
st.title("🏥 Health Diet Planner")

st.write("Fast diet recommendations based on health conditions")

# User inputs, stripped so whitespace-only entries count as empty.
health_condition = st.text_input(
    "Health Condition",
    placeholder="e.g., Diabetes, Hypertension"
).strip()

dietary_restrictions = st.text_input(
    "Dietary Restrictions",
    placeholder="e.g., Vegetarian, No nuts"
).strip()

if st.button("Generate Diet Plan"):
    if health_condition:
        with st.spinner("Generating plan (typically 15-30 seconds)..."):
            try:
                # Create cache key (redundant with lru_cache keying, kept
                # for the existing generate_diet_plan signature).
                cache_key = f"{health_condition}_{dietary_restrictions}"
                diet_plan = generate_diet_plan(
                    health_condition,
                    dietary_restrictions,
                    cache_key
                )

                st.markdown(f"### Diet Plan for {health_condition}")
                st.markdown(diet_plan)

                # Download option
                st.download_button(
                    "💾 Download Plan",
                    diet_plan,
                    file_name="diet_plan.txt"
                )

            except Exception as e:
                # BUG FIX: the failure reason was discarded, leaving users
                # (and logs) with no way to diagnose errors. Surface it.
                st.error(f"Error generating plan: {e}. Try simpler input.")
    else:
        st.warning("Please enter a health condition")

# Sidebar: drop both the lru_cache memo and Streamlit's resource cache.
if st.sidebar.button("Clear Cache"):
    generate_diet_plan.cache_clear()
    st.cache_resource.clear()
    st.success("Cache cleared!")

st.markdown("""
### Important Notes:
- This is an AI-generated diet plan for reference only
- Always consult healthcare professionals before making dietary changes
- Individual needs may vary
- Update your health condition details for more accurate recommendations
""")
requirements.txt CHANGED
@@ -2,5 +2,6 @@
2
  # streamlit is already pre-installed
3
  streamlit
4
  torch
5
- transformers
6
- accelerate
 
 
2
  # streamlit is already pre-installed
3
  streamlit
4
  torch
5
+ transformers>=4.34.0
6
+ accelerate
7
+ scikit-learn