|
|
import os
|
|
|
|
|
|
# Redirect the Hugging Face caches to /tmp (presumably because /tmp is the
# writable location in the hosting container -- confirm). Both variables are
# set before `transformers` is imported further down, since the Hugging Face
# libraries presumably resolve their cache paths at import time.
os.environ.update(dict.fromkeys(('TRANSFORMERS_CACHE', 'HF_HOME'), '/tmp/huggingface'))
|
|
|
|
|
|
from flask import Flask, request, jsonify
|
|
|
from flask_cors import CORS
|
|
|
from transformers import pipeline
|
|
|
import re
|
|
|
import logging
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
|
# Configure root logging at INFO so the model-loading progress messages below
# are emitted, then grab this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
|
|
|
|
|
|
app = Flask(__name__)
# Enable cross-origin requests for every route (no origin restrictions are passed).
CORS(app)

# Model to load into the text-generation pipeline; overridable through the
# MODEL_PATH environment variable.
MODEL_PATH = os.getenv('MODEL_PATH', 'samdak93/qrit-2')

# Text-generation pipeline; remains None until load_model() assigns it.
generator = None
|
|
|
|
|
|
def load_model():
    """Load the text-generation pipeline into the module-level ``generator``.

    ``MODEL_PATH`` is tried first. If loading it raises for any reason, the
    stock ``gpt2`` model is loaded instead. An exception raised while loading
    the fallback is not caught and propagates to the caller.
    """
    global generator
    logger.info("Loading model from %s", MODEL_PATH)
    try:
        generator = pipeline("text-generation", model=MODEL_PATH)
    except Exception as e:
        # logger.exception keeps the traceback: once we quietly switch to GPT-2
        # it is the only record of why the configured model could not be used.
        logger.exception("Error loading custom model: %s", e)
        logger.info("Falling back to default GPT-2")
        generator = pipeline("text-generation", model="gpt2")
        logger.info("Default GPT-2 model loaded")
    else:
        logger.info("Model loaded successfully")
|
|
|
|
|
|
def generate_recipe(goal_prompt, max_length=300, temperature=0.7):
    """Generate a recipe for *goal_prompt* with the text-generation pipeline.

    The model is prompted with ``"### Goal: <goal>\\n### Recipe: "``. Its
    continuation is cut at the next ``### Goal:`` marker, so only the first
    recipe is kept.

    Args:
        goal_prompt: Free-text goal inserted into the prompt.
        max_length: Forwarded to the pipeline as ``max_length``. In
            transformers this bounds prompt plus generated tokens.
        temperature: Sampling temperature. Values > 0 enable sampling. Values
            <= 0 fall back to greedy decoding.

    Returns:
        The prompt followed by the stripped generated recipe text.

    Raises:
        Exception: Whatever the pipeline raises. The error is logged and
            re-raised, so callers can report a failure instead of receiving
            the error message disguised as recipe text.
    """
    if generator is None:
        load_model()

    full_prompt = f"### Goal: {goal_prompt}\n### Recipe: "

    # Without do_sample=True, transformers decodes greedily and ignores
    # `temperature`, so the parameter would have no effect.
    do_sample = temperature > 0
    gen_kwargs = {
        "max_length": max_length,
        "num_return_sequences": 1,
        "do_sample": do_sample,
        # Ask the pipeline for only the newly generated text rather than
        # slicing off len(full_prompt) characters, which assumes the decoded
        # prompt round-trips byte-for-byte.
        "return_full_text": False,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature

    try:
        outputs = generator(full_prompt, **gen_kwargs)
    except Exception as e:
        logger.exception("Error generating recipe: %s", e)
        raise

    recipe_part = outputs[0]["generated_text"]
    # Keep only the first recipe if the model started inventing another goal.
    recipe_part = recipe_part.split("### Goal:", 1)[0].strip()
    return full_prompt + recipe_part
|
|
|
|
|
|
def parse_recipe(recipe_text):
    """Split recipe text into title, ingredients, instructions and nutrition.

    The headers recognised are ``Title:``, ``Ingredients:``, ``Instructions:``
    and ``Nutrition:``. A section whose header is missing keeps its empty
    default in the result.

    Args:
        recipe_text: The recipe text to parse.

    Returns:
        A dict with these keys:

        - ``title`` (str).
        - ``ingredients`` (list of str), with leading ``-``/space bullets
          removed.
        - ``instructions`` (list of str), with a leading ``"N. "`` step
          number removed.
        - ``nutrition`` (dict). Any of ``calories``/``sugar``/``protein``/
          ``fat`` that were found, mapped to an int, or to a float when the
          text gives a fractional value such as ``2.5g fat``.
    """
    recipe_data = {
        "title": "",
        "ingredients": [],
        "instructions": [],
        "nutrition": {},
    }

    # [ \t]* instead of a literal space: accepts "Title:Foo" but never runs
    # onto the following line.
    title_match = re.search(r"Title:[ \t]*(.*?)(?:\n|$)", recipe_text)
    if title_match:
        recipe_data["title"] = title_match.group(1).strip()

    # Model output can be truncated before the "Instructions:" header (e.g. by
    # a length limit). So the section also ends at "Nutrition:" or at the end
    # of the text, instead of being dropped entirely.
    ingredients_section = re.search(
        r"Ingredients:(.*?)(?:Instructions:|Nutrition:|$)", recipe_text, re.DOTALL
    )
    if ingredients_section:
        ingredients_text = ingredients_section.group(1).strip()
        recipe_data["ingredients"] = [
            item.strip().lstrip('- ')
            for item in ingredients_text.split('\n')
            if item.strip()
        ]

    instructions_section = re.search(
        r"Instructions:(.*?)(?:Nutrition:|$)", recipe_text, re.DOTALL
    )
    if instructions_section:
        instructions_list = []
        for line in instructions_section.group(1).strip().split('\n'):
            line = line.strip()
            if line:
                # Drop a leading step number such as "1. ". The required
                # whitespace keeps quantities like "1.5 cups" intact.
                instructions_list.append(re.sub(r"^\d+\.\s", "", line).strip())
        recipe_data["instructions"] = instructions_list

    # Same header tolerance as the instructions regex above, which already
    # treats "Nutrition:" without a trailing space as a section header.
    nutrition_section = re.search(r"Nutrition:[ \t]*(.*?)(?:\n|$)", recipe_text)
    if nutrition_section:
        nutrition_text = nutrition_section.group(1).strip()
        for label in ("calories", "sugar", "protein", "fat"):
            # The optional fraction stops "2.5g fat" from being read as 5. The
            # look-behind stops a match from starting in the middle of a
            # number. IGNORECASE also accepts "250 Calories".
            match = re.search(
                rf"(?<![\d.])(\d+(?:\.\d+)?)\s*(?:g\s*)?{label}",
                nutrition_text,
                re.IGNORECASE,
            )
            if match:
                raw = match.group(1)
                # Integers stay exact ints, as before; only fractions become floats.
                recipe_data["nutrition"][label] = float(raw) if "." in raw else int(raw)

    return recipe_data
|
|
|
|
|
|
@app.route('/', methods=['GET'])
def home():
    """Describe the service: its name, version and the available endpoints."""
    endpoint_docs = {
        "/api/healthcheck": "Check API status",
        "/api/generate": "Generate a recipe",
        "/api/parse": "Parse a recipe text",
    }
    service_info = dict(
        name="Recipe Generation API",
        version="1.0.0",
        endpoints=endpoint_docs,
    )
    return jsonify(service_info)
|
|
|
|
|
|
@app.route('/api/healthcheck', methods=['GET'])
def healthcheck():
    """Liveness probe: report an "ok" status with the server's local time."""
    # isoformat(" ") is exactly what str() produces for a datetime.
    current_time = datetime.now().isoformat(" ")
    payload = {
        "status": "ok",
        "message": "Recipe API is running",
        "timestamp": current_time,
    }
    return jsonify(payload)
|
|
|
|
|
|
@app.route('/api/generate', methods=['POST'])
def api_generate_recipe():
    """Generate a recipe from a JSON request body.

    Expected body: ``{"goal": str, "max_length"?: int, "temperature"?: number}``.

    Responses:
        200: ``{"goal": ..., "recipe_text": ...}`` on success.
        400: The body is not a JSON object, ``goal`` is missing, empty or not
            a string, or ``max_length``/``temperature`` has the wrong type or
            is out of range.
        500: ``generate_recipe`` raised an exception.
    """
    # The request is untrusted and generation cost grows with max_length, so
    # it is bounded. 1024 is GPT-2's context size (the fallback model);
    # presumably adequate for MODEL_PATH too -- TODO confirm.
    max_length_limit = 1024

    # silent=True returns None for a missing or malformed JSON body. The
    # client then gets the JSON 400 below instead of Flask's HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'goal' not in data:
        return jsonify({"error": "Missing 'goal' parameter"}), 400

    goal = data['goal']
    max_length = data.get('max_length', 300)
    temperature = data.get('temperature', 0.7)

    if not isinstance(goal, str) or not goal.strip():
        return jsonify({"error": "'goal' must be a non-empty string"}), 400
    # bool is a subclass of int, so it has to be rejected explicitly.
    if (isinstance(max_length, bool) or not isinstance(max_length, int)
            or not 1 <= max_length <= max_length_limit):
        return jsonify(
            {"error": f"'max_length' must be an integer between 1 and {max_length_limit}"}
        ), 400
    # The chained comparison also rejects NaN and Infinity, which JSON
    # parsing in Python accepts.
    if (isinstance(temperature, bool) or not isinstance(temperature, (int, float))
            or not 0 <= temperature < float("inf")):
        return jsonify({"error": "'temperature' must be a non-negative finite number"}), 400

    try:
        recipe_text = generate_recipe(goal, max_length, temperature)
    except Exception as e:
        logger.exception("Error generating recipe: %s", e)
        return jsonify({"error": str(e)}), 500
    return jsonify({"goal": goal, "recipe_text": recipe_text})
|
|
|
|
|
|
@app.route('/api/parse', methods=['POST'])
def api_parse_recipe():
    """Parse the ``recipe_text`` field of a JSON body with ``parse_recipe``.

    Responses:
        200: The parsed recipe dict.
        400: The body is not a JSON object with a string ``recipe_text``.
        500: ``parse_recipe`` raised an exception.
    """
    # silent=True returns None for a missing or malformed JSON body, so the
    # client gets the JSON 400 below rather than Flask's HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'recipe_text' not in data:
        return jsonify({"error": "Missing 'recipe_text' parameter"}), 400

    recipe_text = data['recipe_text']
    # A non-string is the client's fault. Previously it surfaced as a 500
    # from the regex calls inside parse_recipe.
    if not isinstance(recipe_text, str):
        return jsonify({"error": "'recipe_text' must be a string"}), 400

    try:
        parsed = parse_recipe(recipe_text)
    except Exception as e:
        logger.exception("Error parsing recipe: %s", e)
        return jsonify({"error": str(e)}), 500
    return jsonify(parsed)
|
|
|
|
|
|
if __name__ == '__main__':
    # TRANSFORMERS_CACHE is assigned unconditionally near the top of this
    # module, so the key is guaranteed to exist here.
    cache_dir = os.environ['TRANSFORMERS_CACHE']
    os.makedirs(cache_dir, exist_ok=True)

    # Load eagerly so the first request does not trigger the lazy load in
    # generate_recipe().
    load_model()

    listen_port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port)