# recipe-gen-api / app.py
# Author: Chama99 — last commit 0fba253 ("Update app.py", verified)
from flask import Flask, request, jsonify
from transformers import T5ForConditionalGeneration, T5Tokenizer
from peft import PeftModel
import torch
app = Flask(__name__)

# Hub identifiers for the base checkpoint and the fine-tuned LoRA adapter.
_BASE_MODEL_NAME = "google/flan-t5-small"
_ADAPTER_NAME = "Chama99/flan-t5-small-recipe-generator"

# The base seq2seq model and its tokenizer are loaded from the same checkpoint.
base_model = T5ForConditionalGeneration.from_pretrained(_BASE_MODEL_NAME)
tokenizer = T5Tokenizer.from_pretrained(_BASE_MODEL_NAME)

# Apply the LoRA adapter on top of the base model; `model` is the object
# used for generation elsewhere in this file.
model = PeftModel.from_pretrained(base_model, _ADAPTER_NAME)
# Generated text shorter than this many characters is treated as a failed
# attempt and triggers one retry with looser sampling settings.
_MIN_RECIPE_CHARS = 50


def _generate_text(inputs, prompt, **generation_kwargs):
    """Run one generation pass and return the decoded, stripped text.

    Args:
        inputs: Tokenizer output (``input_ids``/``attention_mask`` tensors)
            to feed to ``model.generate``.
        prompt: The prompt string, removed from the decoded text if echoed.
        **generation_kwargs: Extra keyword arguments for ``model.generate``.

    Returns:
        The decoded text with special tokens removed, the prompt stripped
        out, and surrounding whitespace trimmed.
    """
    # Inference only: never build an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            pad_token_id=tokenizer.pad_token_id,
            **generation_kwargs,
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Defensive: T5 is encoder-decoder, so the decoder output normally does
    # not echo the prompt, but strip it in case it does.
    return text.replace(prompt, "").strip()


def generate_recipe(ingredients):
    """Generate a recipe for the given ingredients with the LoRA-tuned model.

    A first pass uses beam-sampling with stricter repetition controls. If
    that output is shorter than ``_MIN_RECIPE_CHARS`` characters, one retry is
    made with shorter, looser settings and its result is returned as-is,
    even if it is still short.

    Args:
        ingredients: Ingredients description, interpolated verbatim into the
            prompt.

    Returns:
        The generated recipe text (may be empty if the model produces
        nothing).
    """
    prompt = (
        f"Generate a complete recipe using these ingredients: {ingredients}. "
        "Include ingredients list with quantities and detailed step-by-step "
        "cooking instructions:"
    )
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)

    recipe = _generate_text(
        inputs,
        prompt,
        max_length=800,           # room for longer recipes
        num_beams=3,              # combined with do_sample -> beam sampling
        temperature=0.7,
        do_sample=True,
        repetition_penalty=1.2,
        no_repeat_ngram_size=2,   # NOTE(review): forbids any repeated bigram; quite strict for recipes
        early_stopping=True,
    )

    if len(recipe) < _MIN_RECIPE_CHARS:
        # Too short: retry once with fewer beams and slightly more randomness.
        recipe = _generate_text(
            inputs,
            prompt,
            max_length=600,
            num_beams=2,
            temperature=0.8,
            do_sample=True,
            repetition_penalty=1.1,
        )
    return recipe
@app.route('/generate_recipe', methods=['POST'])
def generate_recipe_endpoint():
    """Generate a recipe from a JSON body of the form ``{"ingredients": ...}``.

    ``ingredients`` may be a string, or a list of items that is joined with
    ", " before being passed to the model.

    Returns:
        ``{"recipe": <text>}`` on success, or ``{"error": ...}`` with HTTP 400
        when the body is missing, is not valid JSON, is not a JSON object, or
        has no non-empty ``ingredients`` value.
    """
    # silent=True returns None instead of raising on a missing/invalid JSON
    # body or a non-JSON content type, so all of those get the 400 below.
    data = request.get_json(silent=True)
    # A JSON body may legally be a list, string, or number; only an object
    # supports .get(), so anything else is treated as missing ingredients.
    ingredients = data.get('ingredients', '') if isinstance(data, dict) else ''
    # Accept a JSON array of ingredients without leaking Python list repr
    # (e.g. "['eggs', 'flour']") into the prompt.
    if isinstance(ingredients, (list, tuple)):
        ingredients = ", ".join(str(item) for item in ingredients if item)
    if not ingredients:
        return jsonify({"error": "Please provide ingredients"}), 400
    recipe = generate_recipe(ingredients)
    return jsonify({"recipe": recipe})
# Health check endpoint.
@app.route('/health', methods=['GET'])
def health_check():
    """Return a fixed "healthy" status; does not exercise the model."""
    payload = {
        "status": "healthy",
        "message": "Recipe generator is running",
    }
    return jsonify(payload)
if __name__ == '__main__':
    # Run Flask's built-in server on all network interfaces, port 7860.
    app.run(port=7860, host='0.0.0.0')